diff --git a/.gitignore b/.gitignore index 5a611909..390c8a03 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,4 @@ openspec/ docs/ code-reviews/ AGENTS.md +backend/cmd/server/server diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 664e7aca..ebbaa172 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -70,7 +70,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) diff --git a/backend/ent/account.go b/backend/ent/account.go index 59f55edb..82867111 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/proxy" ) // Account is the model entity for the Account schema. @@ -70,11 +71,15 @@ type Account struct { type AccountEdges struct { // Groups holds the value of the groups edge. Groups []*Group `json:"groups,omitempty"` + // Proxy holds the value of the proxy edge. + Proxy *Proxy `json:"proxy,omitempty"` + // UsageLogs holds the value of the usage_logs edge. 
+ UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // AccountGroups holds the value of the account_groups edge. AccountGroups []*AccountGroup `json:"account_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [4]bool } // GroupsOrErr returns the Groups value or an error if the edge @@ -86,10 +91,30 @@ func (e AccountEdges) GroupsOrErr() ([]*Group, error) { return nil, &NotLoadedError{edge: "groups"} } +// ProxyOrErr returns the Proxy value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AccountEdges) ProxyOrErr() (*Proxy, error) { + if e.Proxy != nil { + return e.Proxy, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: proxy.Label} + } + return nil, &NotLoadedError{edge: "proxy"} +} + +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e AccountEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[2] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // AccountGroupsOrErr returns the AccountGroups value or an error if the edge // was not loaded in eager-loading. func (e AccountEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { - if e.loadedTypes[1] { + if e.loadedTypes[3] { return e.AccountGroups, nil } return nil, &NotLoadedError{edge: "account_groups"} @@ -289,6 +314,16 @@ func (_m *Account) QueryGroups() *GroupQuery { return NewAccountClient(_m.config).QueryGroups(_m) } +// QueryProxy queries the "proxy" edge of the Account entity. +func (_m *Account) QueryProxy() *ProxyQuery { + return NewAccountClient(_m.config).QueryProxy(_m) +} + +// QueryUsageLogs queries the "usage_logs" edge of the Account entity. 
+func (_m *Account) QueryUsageLogs() *UsageLogQuery { + return NewAccountClient(_m.config).QueryUsageLogs(_m) +} + // QueryAccountGroups queries the "account_groups" edge of the Account entity. func (_m *Account) QueryAccountGroups() *AccountGroupQuery { return NewAccountClient(_m.config).QueryAccountGroups(_m) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 65a130fd..c48db1e3 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldSessionWindowStatus = "session_window_status" // EdgeGroups holds the string denoting the groups edge name in mutations. EdgeGroups = "groups" + // EdgeProxy holds the string denoting the proxy edge name in mutations. + EdgeProxy = "proxy" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeAccountGroups holds the string denoting the account_groups edge name in mutations. EdgeAccountGroups = "account_groups" // Table holds the table name of the account in the database. @@ -68,6 +72,20 @@ const ( // GroupsInverseTable is the table name for the Group entity. // It exists in this package in order to avoid circular dependency with the "group" package. GroupsInverseTable = "groups" + // ProxyTable is the table that holds the proxy relation/edge. + ProxyTable = "accounts" + // ProxyInverseTable is the table name for the Proxy entity. + // It exists in this package in order to avoid circular dependency with the "proxy" package. + ProxyInverseTable = "proxies" + // ProxyColumn is the table column denoting the proxy relation/edge. + ProxyColumn = "proxy_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. 
+ UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "account_id" // AccountGroupsTable is the table that holds the account_groups relation/edge. AccountGroupsTable = "account_groups" // AccountGroupsInverseTable is the table name for the AccountGroup entity. @@ -274,6 +292,27 @@ func ByGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByProxyField orders the results by proxy field. +func ByProxyField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newProxyStep(), sql.OrderByField(field, opts...)) + } +} + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByAccountGroupsCount orders the results by account_groups count. 
func ByAccountGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -294,6 +333,20 @@ func newGroupsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2M, false, GroupsTable, GroupsPrimaryKey...), ) } +func newProxyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ProxyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn), + ) +} +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newAccountGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index f54f538f..b79b5f8b 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -495,26 +495,6 @@ func ProxyIDNotIn(vs ...int64) predicate.Account { return predicate.Account(sql.FieldNotIn(FieldProxyID, vs...)) } -// ProxyIDGT applies the GT predicate on the "proxy_id" field. -func ProxyIDGT(v int64) predicate.Account { - return predicate.Account(sql.FieldGT(FieldProxyID, v)) -} - -// ProxyIDGTE applies the GTE predicate on the "proxy_id" field. -func ProxyIDGTE(v int64) predicate.Account { - return predicate.Account(sql.FieldGTE(FieldProxyID, v)) -} - -// ProxyIDLT applies the LT predicate on the "proxy_id" field. -func ProxyIDLT(v int64) predicate.Account { - return predicate.Account(sql.FieldLT(FieldProxyID, v)) -} - -// ProxyIDLTE applies the LTE predicate on the "proxy_id" field. -func ProxyIDLTE(v int64) predicate.Account { - return predicate.Account(sql.FieldLTE(FieldProxyID, v)) -} - // ProxyIDIsNil applies the IsNil predicate on the "proxy_id" field. 
func ProxyIDIsNil() predicate.Account { return predicate.Account(sql.FieldIsNull(FieldProxyID)) @@ -1153,6 +1133,52 @@ func HasGroupsWith(preds ...predicate.Group) predicate.Account { }) } +// HasProxy applies the HasEdge predicate on the "proxy" edge. +func HasProxy() predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasProxyWith applies the HasEdge predicate on the "proxy" edge with a given conditions (other predicates). +func HasProxyWith(preds ...predicate.Proxy) predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := newProxyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasAccountGroups applies the HasEdge predicate on the "account_groups" edge. 
func HasAccountGroups() predicate.Account { return predicate.Account(func(s *sql.Selector) { diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 6d813817..2fb52a81 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -13,6 +13,8 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountCreate is the builder for creating a Account entity. @@ -292,6 +294,26 @@ func (_c *AccountCreate) AddGroups(v ...*Group) *AccountCreate { return _c.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_c *AccountCreate) SetProxy(v *Proxy) *AccountCreate { + return _c.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *AccountCreate) AddUsageLogIDs(ids ...int64) *AccountCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *AccountCreate) AddUsageLogs(v ...*UsageLog) *AccountCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. 
func (_c *AccountCreate) Mutation() *AccountMutation { return _c.mutation @@ -495,10 +517,6 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldExtra, field.TypeJSON, value) _node.Extra = value } - if value, ok := _c.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - _node.ProxyID = &value - } if value, ok := _c.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) _node.Concurrency = value @@ -567,6 +585,39 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { edge.Target.Fields = specE.Fields _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.ProxyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.ProxyID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -721,12 +772,6 @@ func (u *AccountUpsert) UpdateProxyID() *AccountUpsert { return u } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsert) AddProxyID(v int64) *AccountUpsert { - u.Add(account.FieldProxyID, v) - return u -} - // ClearProxyID clears the value of the "proxy_id" field. 
func (u *AccountUpsert) ClearProxyID() *AccountUpsert { u.SetNull(account.FieldProxyID) @@ -1094,13 +1139,6 @@ func (u *AccountUpsertOne) SetProxyID(v int64) *AccountUpsertOne { }) } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsertOne) AddProxyID(v int64) *AccountUpsertOne { - return u.Update(func(s *AccountUpsert) { - s.AddProxyID(v) - }) -} - // UpdateProxyID sets the "proxy_id" field to the value that was provided on create. func (u *AccountUpsertOne) UpdateProxyID() *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -1676,13 +1714,6 @@ func (u *AccountUpsertBulk) SetProxyID(v int64) *AccountUpsertBulk { }) } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsertBulk) AddProxyID(v int64) *AccountUpsertBulk { - return u.Update(func(s *AccountUpsert) { - s.AddProxyID(v) - }) -} - // UpdateProxyID sets the "proxy_id" field to the value that was provided on create. func (u *AccountUpsertBulk) UpdateProxyID() *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_query.go b/backend/ent/account_query.go index e5712884..3e363ecd 100644 --- a/backend/ent/account_query.go +++ b/backend/ent/account_query.go @@ -16,6 +16,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/accountgroup" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountQuery is the builder for querying Account entities. @@ -26,6 +28,8 @@ type AccountQuery struct { inters []Interceptor predicates []predicate.Account withGroups *GroupQuery + withProxy *ProxyQuery + withUsageLogs *UsageLogQuery withAccountGroups *AccountGroupQuery // intermediate query (i.e. traversal path). sql *sql.Selector @@ -85,6 +89,50 @@ func (_q *AccountQuery) QueryGroups() *GroupQuery { return query } +// QueryProxy chains the current query on the "proxy" edge. 
+func (_q *AccountQuery) QueryProxy() *ProxyQuery { + query := (&ProxyClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, selector), + sqlgraph.To(proxy.Table, proxy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *AccountQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryAccountGroups chains the current query on the "account_groups" edge. func (_q *AccountQuery) QueryAccountGroups() *AccountGroupQuery { query := (&AccountGroupClient{config: _q.config}).Query() @@ -300,6 +348,8 @@ func (_q *AccountQuery) Clone() *AccountQuery { inters: append([]Interceptor{}, _q.inters...), predicates: append([]predicate.Account{}, _q.predicates...), withGroups: _q.withGroups.Clone(), + withProxy: _q.withProxy.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withAccountGroups: _q.withAccountGroups.Clone(), // clone intermediate query. 
sql: _q.sql.Clone(), @@ -318,6 +368,28 @@ func (_q *AccountQuery) WithGroups(opts ...func(*GroupQuery)) *AccountQuery { return _q } +// WithProxy tells the query-builder to eager-load the nodes that are connected to +// the "proxy" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AccountQuery) WithProxy(opts ...func(*ProxyQuery)) *AccountQuery { + query := (&ProxyClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withProxy = query + return _q +} + +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AccountQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *AccountQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithAccountGroups tells the query-builder to eager-load the nodes that are connected to // the "account_groups" edge. The optional arguments are used to configure the query builder of the edge. 
func (_q *AccountQuery) WithAccountGroups(opts ...func(*AccountGroupQuery)) *AccountQuery { @@ -407,8 +479,10 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco var ( nodes = []*Account{} _spec = _q.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [4]bool{ _q.withGroups != nil, + _q.withProxy != nil, + _q.withUsageLogs != nil, _q.withAccountGroups != nil, } ) @@ -437,6 +511,19 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco return nil, err } } + if query := _q.withProxy; query != nil { + if err := _q.loadProxy(ctx, query, nodes, nil, + func(n *Account, e *Proxy) { n.Edges.Proxy = e }); err != nil { + return nil, err + } + } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *Account) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *Account, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withAccountGroups; query != nil { if err := _q.loadAccountGroups(ctx, query, nodes, func(n *Account) { n.Edges.AccountGroups = []*AccountGroup{} }, @@ -508,6 +595,68 @@ func (_q *AccountQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes } return nil } +func (_q *AccountQuery) loadProxy(ctx context.Context, query *ProxyQuery, nodes []*Account, init func(*Account), assign func(*Account, *Proxy)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*Account) + for i := range nodes { + if nodes[i].ProxyID == nil { + continue + } + fk := *nodes[i].ProxyID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(proxy.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "proxy_id" returned %v`, n.ID) + } + 
for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AccountQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Account, init func(*Account), assign func(*Account, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Account) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldAccountID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(account.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.AccountID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "account_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGroupQuery, nodes []*Account, init func(*Account), assign func(*Account, *AccountGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*Account) @@ -564,6 +713,9 @@ func (_q *AccountQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if _q.withProxy != nil { + _spec.Node.AddColumnOnce(account.FieldProxyID) + } } if ps := _q.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 49eaaea8..cf8708c5 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -14,6 +14,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountUpdate is the builder for updating 
Account entities. @@ -111,7 +113,6 @@ func (_u *AccountUpdate) SetExtra(v map[string]interface{}) *AccountUpdate { // SetProxyID sets the "proxy_id" field. func (_u *AccountUpdate) SetProxyID(v int64) *AccountUpdate { - _u.mutation.ResetProxyID() _u.mutation.SetProxyID(v) return _u } @@ -124,12 +125,6 @@ func (_u *AccountUpdate) SetNillableProxyID(v *int64) *AccountUpdate { return _u } -// AddProxyID adds value to the "proxy_id" field. -func (_u *AccountUpdate) AddProxyID(v int64) *AccountUpdate { - _u.mutation.AddProxyID(v) - return _u -} - // ClearProxyID clears the value of the "proxy_id" field. func (_u *AccountUpdate) ClearProxyID() *AccountUpdate { _u.mutation.ClearProxyID() @@ -381,6 +376,26 @@ func (_u *AccountUpdate) AddGroups(v ...*Group) *AccountUpdate { return _u.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_u *AccountUpdate) SetProxy(v *Proxy) *AccountUpdate { + return _u.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *AccountUpdate) AddUsageLogIDs(ids ...int64) *AccountUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdate) AddUsageLogs(v ...*UsageLog) *AccountUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. func (_u *AccountUpdate) Mutation() *AccountMutation { return _u.mutation @@ -407,6 +422,33 @@ func (_u *AccountUpdate) RemoveGroups(v ...*Group) *AccountUpdate { return _u.RemoveGroupIDs(ids...) } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (_u *AccountUpdate) ClearProxy() *AccountUpdate { + _u.mutation.ClearProxy() + return _u +} + +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. 
+func (_u *AccountUpdate) ClearUsageLogs() *AccountUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *AccountUpdate) RemoveUsageLogIDs(ids ...int64) *AccountUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *AccountUpdate) RemoveUsageLogs(v ...*UsageLog) *AccountUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *AccountUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -515,15 +557,6 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Extra(); ok { _spec.SetField(account.FieldExtra, field.TypeJSON, value) } - if value, ok := _u.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedProxyID(); ok { - _spec.AddField(account.FieldProxyID, field.TypeInt64, value) - } - if _u.mutation.ProxyIDCleared() { - _spec.ClearField(account.FieldProxyID, field.TypeInt64) - } if value, ok := _u.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) } @@ -647,6 +680,80 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.ProxyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ProxyIDs(); 
len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{account.Label} @@ -749,7 +856,6 @@ func (_u *AccountUpdateOne) SetExtra(v 
map[string]interface{}) *AccountUpdateOne // SetProxyID sets the "proxy_id" field. func (_u *AccountUpdateOne) SetProxyID(v int64) *AccountUpdateOne { - _u.mutation.ResetProxyID() _u.mutation.SetProxyID(v) return _u } @@ -762,12 +868,6 @@ func (_u *AccountUpdateOne) SetNillableProxyID(v *int64) *AccountUpdateOne { return _u } -// AddProxyID adds value to the "proxy_id" field. -func (_u *AccountUpdateOne) AddProxyID(v int64) *AccountUpdateOne { - _u.mutation.AddProxyID(v) - return _u -} - // ClearProxyID clears the value of the "proxy_id" field. func (_u *AccountUpdateOne) ClearProxyID() *AccountUpdateOne { _u.mutation.ClearProxyID() @@ -1019,6 +1119,26 @@ func (_u *AccountUpdateOne) AddGroups(v ...*Group) *AccountUpdateOne { return _u.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_u *AccountUpdateOne) SetProxy(v *Proxy) *AccountUpdateOne { + return _u.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *AccountUpdateOne) AddUsageLogIDs(ids ...int64) *AccountUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdateOne) AddUsageLogs(v ...*UsageLog) *AccountUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. func (_u *AccountUpdateOne) Mutation() *AccountMutation { return _u.mutation @@ -1045,6 +1165,33 @@ func (_u *AccountUpdateOne) RemoveGroups(v ...*Group) *AccountUpdateOne { return _u.RemoveGroupIDs(ids...) } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (_u *AccountUpdateOne) ClearProxy() *AccountUpdateOne { + _u.mutation.ClearProxy() + return _u +} + +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. 
+func (_u *AccountUpdateOne) ClearUsageLogs() *AccountUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *AccountUpdateOne) RemoveUsageLogIDs(ids ...int64) *AccountUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *AccountUpdateOne) RemoveUsageLogs(v ...*UsageLog) *AccountUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the AccountUpdate builder. func (_u *AccountUpdateOne) Where(ps ...predicate.Account) *AccountUpdateOne { _u.mutation.Where(ps...) @@ -1183,15 +1330,6 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.Extra(); ok { _spec.SetField(account.FieldExtra, field.TypeJSON, value) } - if value, ok := _u.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedProxyID(); ok { - _spec.AddField(account.FieldProxyID, field.TypeInt64, value) - } - if _u.mutation.ProxyIDCleared() { - _spec.ClearField(account.FieldProxyID, field.TypeInt64) - } if value, ok := _u.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) } @@ -1315,6 +1453,80 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.ProxyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ProxyIDs(); 
len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Account{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 30cf9b4d..61ac15fa 100644 --- a/backend/ent/apikey.go +++ 
b/backend/ent/apikey.go @@ -47,9 +47,11 @@ type ApiKeyEdges struct { User *User `json:"user,omitempty"` // Group holds the value of the group edge. Group *Group `json:"group,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool } // UserOrErr returns the User value or an error if the edge @@ -74,6 +76,15 @@ func (e ApiKeyEdges) GroupOrErr() (*Group, error) { return nil, &NotLoadedError{edge: "group"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e ApiKeyEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[2] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*ApiKey) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -179,6 +190,11 @@ func (_m *ApiKey) QueryGroup() *GroupQuery { return NewApiKeyClient(_m.config).QueryGroup(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the ApiKey entity. +func (_m *ApiKey) QueryUsageLogs() *UsageLogQuery { + return NewApiKeyClient(_m.config).QueryUsageLogs(_m) +} + // Update returns a builder for updating this ApiKey. // Note that you need to call ApiKey.Unwrap() before calling this method if this ApiKey // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 4eba5f53..f03b2daa 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -35,6 +35,8 @@ const ( EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. EdgeGroup = "group" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. 
+ EdgeUsageLogs = "usage_logs" // Table holds the table name of the apikey in the database. Table = "api_keys" // UserTable is the table that holds the user relation/edge. @@ -51,6 +53,13 @@ const ( GroupInverseTable = "groups" // GroupColumn is the table column denoting the group relation/edge. GroupColumn = "group_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "api_key_id" ) // Columns holds all SQL columns for apikey fields. @@ -161,6 +170,20 @@ func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) } } + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} func newUserStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -175,3 +198,10 @@ func newGroupStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 11cabd3f..95bc4e2a 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -516,6 +516,29 @@ func HasGroupWith(preds ...predicate.Group) predicate.ApiKey { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.ApiKey { + return predicate.ApiKey(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.ApiKey { + return predicate.ApiKey(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. 
func And(predicates ...predicate.ApiKey) predicate.ApiKey { return predicate.ApiKey(sql.AndPredicates(predicates...)) diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 8d7ddb69..5b984b21 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) @@ -122,6 +123,21 @@ func (_c *ApiKeyCreate) SetGroup(v *Group) *ApiKeyCreate { return _c.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *ApiKeyCreate) AddUsageLogIDs(ids ...int64) *ApiKeyCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *ApiKeyCreate) AddUsageLogs(v ...*UsageLog) *ApiKeyCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. 
func (_c *ApiKeyCreate) Mutation() *ApiKeyMutation { return _c.mutation @@ -303,6 +319,22 @@ func (_c *ApiKeyCreate) createSpec() (*ApiKey, *sqlgraph.CreateSpec) { _node.GroupID = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/apikey_query.go b/backend/ent/apikey_query.go index 86051a60..d4029feb 100644 --- a/backend/ent/apikey_query.go +++ b/backend/ent/apikey_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -14,18 +15,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) // ApiKeyQuery is the builder for querying ApiKey entities. type ApiKeyQuery struct { config - ctx *QueryContext - order []apikey.OrderOption - inters []Interceptor - predicates []predicate.ApiKey - withUser *UserQuery - withGroup *GroupQuery + ctx *QueryContext + order []apikey.OrderOption + inters []Interceptor + predicates []predicate.ApiKey + withUser *UserQuery + withGroup *GroupQuery + withUsageLogs *UsageLogQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -106,6 +109,28 @@ func (_q *ApiKeyQuery) QueryGroup() *GroupQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. 
+func (_q *ApiKeyQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(apikey.Table, apikey.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first ApiKey entity from the query. // Returns a *NotFoundError when no ApiKey was found. func (_q *ApiKeyQuery) First(ctx context.Context) (*ApiKey, error) { @@ -293,13 +318,14 @@ func (_q *ApiKeyQuery) Clone() *ApiKeyQuery { return nil } return &ApiKeyQuery{ - config: _q.config, - ctx: _q.ctx.Clone(), - order: append([]apikey.OrderOption{}, _q.order...), - inters: append([]Interceptor{}, _q.inters...), - predicates: append([]predicate.ApiKey{}, _q.predicates...), - withUser: _q.withUser.Clone(), - withGroup: _q.withGroup.Clone(), + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]apikey.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ApiKey{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withGroup: _q.withGroup.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, @@ -328,6 +354,17 @@ func (_q *ApiKeyQuery) WithGroup(opts ...func(*GroupQuery)) *ApiKeyQuery { return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. 
+func (_q *ApiKeyQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *ApiKeyQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -406,9 +443,10 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe var ( nodes = []*ApiKey{} _spec = _q.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ _q.withUser != nil, _q.withGroup != nil, + _q.withUsageLogs != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -441,6 +479,13 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *ApiKey) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *ApiKey, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -505,6 +550,36 @@ func (_q *ApiKeyQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes [ } return nil } +func (_q *ApiKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*ApiKey, init func(*ApiKey), assign func(*ApiKey, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*ApiKey) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldAPIKeyID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(apikey.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.APIKeyID + node, ok := nodeids[fk] + if !ok { + 
return fmt.Errorf(`unexpected referenced foreign-key "api_key_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *ApiKeyQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 3917d068..3259bfd9 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) @@ -142,6 +143,21 @@ func (_u *ApiKeyUpdate) SetGroup(v *Group) *ApiKeyUpdate { return _u.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *ApiKeyUpdate) AddUsageLogIDs(ids ...int64) *ApiKeyUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdate) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. func (_u *ApiKeyUpdate) Mutation() *ApiKeyMutation { return _u.mutation @@ -159,6 +175,27 @@ func (_u *ApiKeyUpdate) ClearGroup() *ApiKeyUpdate { return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdate) ClearUsageLogs() *ApiKeyUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *ApiKeyUpdate) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. 
+func (_u *ApiKeyUpdate) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *ApiKeyUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -312,6 +349,51 @@ func (_u *ApiKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); 
ok { err = &NotFoundError{apikey.Label} @@ -444,6 +526,21 @@ func (_u *ApiKeyUpdateOne) SetGroup(v *Group) *ApiKeyUpdateOne { return _u.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *ApiKeyUpdateOne) AddUsageLogIDs(ids ...int64) *ApiKeyUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdateOne) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. func (_u *ApiKeyUpdateOne) Mutation() *ApiKeyMutation { return _u.mutation @@ -461,6 +558,27 @@ func (_u *ApiKeyUpdateOne) ClearGroup() *ApiKeyUpdateOne { return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdateOne) ClearUsageLogs() *ApiKeyUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *ApiKeyUpdateOne) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *ApiKeyUpdateOne) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the ApiKeyUpdate builder. func (_u *ApiKeyUpdateOne) Where(ps ...predicate.ApiKey) *ApiKeyUpdateOne { _u.mutation.Where(ps...) 
@@ -644,6 +762,51 @@ func (_u *ApiKeyUpdateOne) sqlSave(ctx context.Context) (_node *ApiKey, err erro } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &ApiKey{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/client.go b/backend/ent/client.go index 113dc7ff..909226fa 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" 
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -48,6 +49,8 @@ type Client struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // UsageLog is the client for interacting with the UsageLog builders. + UsageLog *UsageLogClient // User is the client for interacting with the User builders. User *UserClient // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. @@ -72,6 +75,7 @@ func (c *Client) init() { c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) c.Setting = NewSettingClient(c.config) + c.UsageLog = NewUsageLogClient(c.config) c.User = NewUserClient(c.config) c.UserAllowedGroup = NewUserAllowedGroupClient(c.config) c.UserSubscription = NewUserSubscriptionClient(c.config) @@ -174,6 +178,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), @@ -203,6 +208,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), @@ -236,7 +242,7 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, - c.User, c.UserAllowedGroup, c.UserSubscription, + c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, } { n.Use(hooks...) 
} @@ -247,7 +253,7 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, - c.User, c.UserAllowedGroup, c.UserSubscription, + c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -270,6 +276,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.RedeemCode.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) + case *UsageLogMutation: + return c.UsageLog.mutate(ctx, m) case *UserMutation: return c.User.mutate(ctx, m) case *UserAllowedGroupMutation: @@ -405,6 +413,38 @@ func (c *AccountClient) QueryGroups(_m *Account) *GroupQuery { return query } +// QueryProxy queries the proxy edge of a Account. +func (c *AccountClient) QueryProxy(_m *Account) *ProxyQuery { + query := (&ProxyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, id), + sqlgraph.To(proxy.Table, proxy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryUsageLogs queries the usage_logs edge of a Account. 
+func (c *AccountClient) QueryUsageLogs(_m *Account) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryAccountGroups queries the account_groups edge of a Account. func (c *AccountClient) QueryAccountGroups(_m *Account) *AccountGroupQuery { query := (&AccountGroupClient{config: c.config}).Query() @@ -704,6 +744,22 @@ func (c *ApiKeyClient) QueryGroup(_m *ApiKey) *GroupQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a ApiKey. +func (c *ApiKeyClient) QueryUsageLogs(_m *ApiKey) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(apikey.Table, apikey.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ApiKeyClient) Hooks() []Hook { hooks := c.hooks.ApiKey @@ -887,6 +943,22 @@ func (c *GroupClient) QuerySubscriptions(_m *Group) *UserSubscriptionQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a Group. 
+func (c *GroupClient) QueryUsageLogs(_m *Group) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryAccounts queries the accounts edge of a Group. func (c *GroupClient) QueryAccounts(_m *Group) *AccountQuery { query := (&AccountClient{config: c.config}).Query() @@ -1086,6 +1158,22 @@ func (c *ProxyClient) GetX(ctx context.Context, id int64) *Proxy { return obj } +// QueryAccounts queries the accounts edge of a Proxy. +func (c *ProxyClient) QueryAccounts(_m *Proxy) *AccountQuery { + query := (&AccountClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(proxy.Table, proxy.FieldID, id), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ProxyClient) Hooks() []Hook { hooks := c.hooks.Proxy @@ -1411,6 +1499,219 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value, } } +// UsageLogClient is a client for the UsageLog schema. +type UsageLogClient struct { + config +} + +// NewUsageLogClient returns a client for the UsageLog from the given config. +func NewUsageLogClient(c config) *UsageLogClient { + return &UsageLogClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `usagelog.Hooks(f(g(h())))`. 
+func (c *UsageLogClient) Use(hooks ...Hook) { + c.hooks.UsageLog = append(c.hooks.UsageLog, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `usagelog.Intercept(f(g(h())))`. +func (c *UsageLogClient) Intercept(interceptors ...Interceptor) { + c.inters.UsageLog = append(c.inters.UsageLog, interceptors...) +} + +// Create returns a builder for creating a UsageLog entity. +func (c *UsageLogClient) Create() *UsageLogCreate { + mutation := newUsageLogMutation(c.config, OpCreate) + return &UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of UsageLog entities. +func (c *UsageLogClient) CreateBulk(builders ...*UsageLogCreate) *UsageLogCreateBulk { + return &UsageLogCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UsageLogClient) MapCreateBulk(slice any, setFunc func(*UsageLogCreate, int)) *UsageLogCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UsageLogCreateBulk{err: fmt.Errorf("calling to UsageLogClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UsageLogCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UsageLogCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for UsageLog. +func (c *UsageLogClient) Update() *UsageLogUpdate { + mutation := newUsageLogMutation(c.config, OpUpdate) + return &UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. 
+func (c *UsageLogClient) UpdateOne(_m *UsageLog) *UsageLogUpdateOne { + mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLog(_m)) + return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UsageLogClient) UpdateOneID(id int64) *UsageLogUpdateOne { + mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLogID(id)) + return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for UsageLog. +func (c *UsageLogClient) Delete() *UsageLogDelete { + mutation := newUsageLogMutation(c.config, OpDelete) + return &UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *UsageLogClient) DeleteOne(_m *UsageLog) *UsageLogDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UsageLogClient) DeleteOneID(id int64) *UsageLogDeleteOne { + builder := c.Delete().Where(usagelog.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UsageLogDeleteOne{builder} +} + +// Query returns a query builder for UsageLog. +func (c *UsageLogClient) Query() *UsageLogQuery { + return &UsageLogQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUsageLog}, + inters: c.Interceptors(), + } +} + +// Get returns a UsageLog entity by its id. +func (c *UsageLogClient) Get(ctx context.Context, id int64) (*UsageLog, error) { + return c.Query().Where(usagelog.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UsageLogClient) GetX(ctx context.Context, id int64) *UsageLog { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a UsageLog. 
+func (c *UsageLogClient) QueryUser(_m *UsageLog) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAPIKey queries the api_key edge of a UsageLog. +func (c *UsageLogClient) QueryAPIKey(_m *UsageLog) *ApiKeyQuery { + query := (&ApiKeyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(apikey.Table, apikey.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAccount queries the account edge of a UsageLog. +func (c *UsageLogClient) QueryAccount(_m *UsageLog) *AccountQuery { + query := (&AccountClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryGroup queries the group edge of a UsageLog. 
+func (c *UsageLogClient) QueryGroup(_m *UsageLog) *GroupQuery { + query := (&GroupClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QuerySubscription queries the subscription edge of a UsageLog. +func (c *UsageLogClient) QuerySubscription(_m *UsageLog) *UserSubscriptionQuery { + query := (&UserSubscriptionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(usersubscription.Table, usersubscription.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *UsageLogClient) Hooks() []Hook { + return c.hooks.UsageLog +} + +// Interceptors returns the client interceptors. 
+func (c *UsageLogClient) Interceptors() []Interceptor { + return c.inters.UsageLog +} + +func (c *UsageLogClient) mutate(ctx context.Context, m *UsageLogMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown UsageLog mutation op: %q", m.Op()) + } +} + // UserClient is a client for the User schema. type UserClient struct { config @@ -1599,6 +1900,22 @@ func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a User. +func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -1914,14 +2231,32 @@ func (c *UserSubscriptionClient) QueryAssignedByUser(_m *UserSubscription) *User return query } +// QueryUsageLogs queries the usage_logs edge of a UserSubscription. 
+func (c *UserSubscriptionClient) QueryUsageLogs(_m *UserSubscription) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usersubscription.Table, usersubscription.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *UserSubscriptionClient) Hooks() []Hook { - return c.hooks.UserSubscription + hooks := c.hooks.UserSubscription + return append(hooks[:len(hooks):len(hooks)], usersubscription.Hooks[:]...) } // Interceptors returns the client interceptors. func (c *UserSubscriptionClient) Interceptors() []Interceptor { - return c.inters.UserSubscription + inters := c.inters.UserSubscription + return append(inters[:len(inters):len(inters)], usersubscription.Interceptors[:]...) } func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscriptionMutation) (Value, error) { @@ -1942,16 +2277,15 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. 
type ( hooks struct { - Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User, - UserAllowedGroup, UserSubscription []ent.Hook + Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, + User, UserAllowedGroup, UserSubscription []ent.Hook } inters struct { - Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User, - UserAllowedGroup, UserSubscription []ent.Interceptor + Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, + User, UserAllowedGroup, UserSubscription []ent.Interceptor } ) -// ExecContext 透传到底层 driver,用于在 ent 事务中执行原生 SQL(例如同步 legacy 字段)。 // ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it. // See, database/sql#DB.ExecContext for more information. func (c *config) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { @@ -1964,7 +2298,6 @@ func (c *config) ExecContext(ctx context.Context, query string, args ...any) (st return ex.ExecContext(ctx, query, args...) } -// QueryContext 透传到底层 driver,用于在事务内执行原生查询并共享锁/一致性语义。 // QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it. // See, database/sql#DB.QueryContext for more information. 
func (c *config) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { diff --git a/backend/ent/ent.go b/backend/ent/ent.go index e2c8b56c..29890206 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -89,6 +90,7 @@ func checkColumn(t, c string) error { proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, setting.Table: setting.ValidColumn, + usagelog.Table: usagelog.ValidColumn, user.Table: user.ValidColumn, userallowedgroup.Table: userallowedgroup.ValidColumn, usersubscription.Table: usersubscription.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index fecb202a..9b1e8604 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -43,6 +43,8 @@ type Group struct { WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"` // MonthlyLimitUsd holds the value of the "monthly_limit_usd" field. MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` + // DefaultValidityDays holds the value of the "default_validity_days" field. + DefaultValidityDays int `json:"default_validity_days,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -57,6 +59,8 @@ type GroupEdges struct { RedeemCodes []*RedeemCode `json:"redeem_codes,omitempty"` // Subscriptions holds the value of the subscriptions edge. Subscriptions []*UserSubscription `json:"subscriptions,omitempty"` + // UsageLogs holds the value of the usage_logs edge. 
+ UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // Accounts holds the value of the accounts edge. Accounts []*Account `json:"accounts,omitempty"` // AllowedUsers holds the value of the allowed_users edge. @@ -67,7 +71,7 @@ type GroupEdges struct { UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [7]bool + loadedTypes [8]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -97,10 +101,19 @@ func (e GroupEdges) SubscriptionsOrErr() ([]*UserSubscription, error) { return nil, &NotLoadedError{edge: "subscriptions"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e GroupEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[3] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // AccountsOrErr returns the Accounts value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) AccountsOrErr() ([]*Account, error) { - if e.loadedTypes[3] { + if e.loadedTypes[4] { return e.Accounts, nil } return nil, &NotLoadedError{edge: "accounts"} @@ -109,7 +122,7 @@ func (e GroupEdges) AccountsOrErr() ([]*Account, error) { // AllowedUsersOrErr returns the AllowedUsers value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) { - if e.loadedTypes[4] { + if e.loadedTypes[5] { return e.AllowedUsers, nil } return nil, &NotLoadedError{edge: "allowed_users"} @@ -118,7 +131,7 @@ func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) { // AccountGroupsOrErr returns the AccountGroups value or an error if the edge // was not loaded in eager-loading. 
func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { - if e.loadedTypes[5] { + if e.loadedTypes[6] { return e.AccountGroups, nil } return nil, &NotLoadedError{edge: "account_groups"} @@ -127,7 +140,7 @@ func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[6] { + if e.loadedTypes[7] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -142,7 +155,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd: values[i] = new(sql.NullFloat64) - case group.FieldID: + case group.FieldID, group.FieldDefaultValidityDays: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -252,6 +265,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.MonthlyLimitUsd = new(float64) *_m.MonthlyLimitUsd = value.Float64 } + case group.FieldDefaultValidityDays: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field default_validity_days", values[i]) + } else if value.Valid { + _m.DefaultValidityDays = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -280,6 +299,11 @@ func (_m *Group) QuerySubscriptions() *UserSubscriptionQuery { return NewGroupClient(_m.config).QuerySubscriptions(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the Group entity. +func (_m *Group) QueryUsageLogs() *UsageLogQuery { + return NewGroupClient(_m.config).QueryUsageLogs(_m) +} + // QueryAccounts queries the "accounts" edge of the Group entity. 
func (_m *Group) QueryAccounts() *AccountQuery { return NewGroupClient(_m.config).QueryAccounts(_m) @@ -371,6 +395,9 @@ func (_m *Group) String() string { builder.WriteString("monthly_limit_usd=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("default_validity_days=") + builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 05a5673d..8dc53c49 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -41,12 +41,16 @@ const ( FieldWeeklyLimitUsd = "weekly_limit_usd" // FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database. FieldMonthlyLimitUsd = "monthly_limit_usd" + // FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database. + FieldDefaultValidityDays = "default_validity_days" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. EdgeRedeemCodes = "redeem_codes" // EdgeSubscriptions holds the string denoting the subscriptions edge name in mutations. EdgeSubscriptions = "subscriptions" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeAccounts holds the string denoting the accounts edge name in mutations. EdgeAccounts = "accounts" // EdgeAllowedUsers holds the string denoting the allowed_users edge name in mutations. @@ -78,6 +82,13 @@ const ( SubscriptionsInverseTable = "user_subscriptions" // SubscriptionsColumn is the table column denoting the subscriptions relation/edge. SubscriptionsColumn = "group_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. 
+ // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "group_id" // AccountsTable is the table that holds the accounts relation/edge. The primary key declared below. AccountsTable = "account_groups" // AccountsInverseTable is the table name for the Account entity. @@ -120,6 +131,7 @@ var Columns = []string{ FieldDailyLimitUsd, FieldWeeklyLimitUsd, FieldMonthlyLimitUsd, + FieldDefaultValidityDays, } var ( @@ -173,6 +185,8 @@ var ( DefaultSubscriptionType string // SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save. SubscriptionTypeValidator func(string) error + // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. + DefaultDefaultValidityDays int ) // OrderOption defines the ordering options for the Group queries. @@ -248,6 +262,11 @@ func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc() } +// ByDefaultValidityDays orders the results by the default_validity_days field. +func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -290,6 +309,20 @@ func BySubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. 
+func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByAccountsCount orders the results by accounts count. func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -366,6 +399,13 @@ func newSubscriptionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, SubscriptionsTable, SubscriptionsColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newAccountsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index fd597be9..ac18a418 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -120,6 +120,11 @@ func MonthlyLimitUsd(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldMonthlyLimitUsd, v)) } +// DefaultValidityDays applies equality check predicate on the "default_validity_days" field. It's identical to DefaultValidityDaysEQ. +func DefaultValidityDays(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -785,6 +790,46 @@ func MonthlyLimitUsdNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldMonthlyLimitUsd)) } +// DefaultValidityDaysEQ applies the EQ predicate on the "default_validity_days" field. 
+func DefaultValidityDaysEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysNEQ applies the NEQ predicate on the "default_validity_days" field. +func DefaultValidityDaysNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysIn applies the In predicate on the "default_validity_days" field. +func DefaultValidityDaysIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldDefaultValidityDays, vs...)) +} + +// DefaultValidityDaysNotIn applies the NotIn predicate on the "default_validity_days" field. +func DefaultValidityDaysNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldDefaultValidityDays, vs...)) +} + +// DefaultValidityDaysGT applies the GT predicate on the "default_validity_days" field. +func DefaultValidityDaysGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysGTE applies the GTE predicate on the "default_validity_days" field. +func DefaultValidityDaysGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysLT applies the LT predicate on the "default_validity_days" field. +func DefaultValidityDaysLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysLTE applies the LTE predicate on the "default_validity_days" field. +func DefaultValidityDaysLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { @@ -854,6 +899,29 @@ func HasSubscriptionsWith(preds ...predicate.UserSubscription) predicate.Group { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. 
+func HasUsageLogs() predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasAccounts applies the HasEdge predicate on the "accounts" edge. func HasAccounts() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 873cf84c..383a1352 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -201,6 +202,20 @@ func (_c *GroupCreate) SetNillableMonthlyLimitUsd(v *float64) *GroupCreate { return _c } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_c *GroupCreate) SetDefaultValidityDays(v int) *GroupCreate { + _c.mutation.SetDefaultValidityDays(v) + return _c +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate { + if v != nil { + _c.SetDefaultValidityDays(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -246,6 +261,21 @@ func (_c *GroupCreate) AddSubscriptions(v ...*UserSubscription) *GroupCreate { return _c.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *GroupCreate) AddUsageLogIDs(ids ...int64) *GroupCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *GroupCreate) AddUsageLogs(v ...*UsageLog) *GroupCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_c *GroupCreate) AddAccountIDs(ids ...int64) *GroupCreate { _c.mutation.AddAccountIDs(ids...) @@ -347,6 +377,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSubscriptionType _c.mutation.SetSubscriptionType(v) } + if _, ok := _c.mutation.DefaultValidityDays(); !ok { + v := group.DefaultDefaultValidityDays + _c.mutation.SetDefaultValidityDays(v) + } return nil } @@ -396,6 +430,9 @@ func (_c *GroupCreate) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if _, ok := _c.mutation.DefaultValidityDays(); !ok { + return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} + } return nil } @@ -475,6 +512,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMonthlyLimitUsd, field.TypeFloat64, value) _node.MonthlyLimitUsd = &value } + if value, ok := _c.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + _node.DefaultValidityDays = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -523,6 +564,22 @@ func (_c *GroupCreate) createSpec() (*Group, 
*sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, @@ -813,6 +870,24 @@ func (u *GroupUpsert) ClearMonthlyLimitUsd() *GroupUpsert { return u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsert) SetDefaultValidityDays(v int) *GroupUpsert { + u.Set(group.FieldDefaultValidityDays, v) + return u +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsert) UpdateDefaultValidityDays() *GroupUpsert { + u.SetExcluded(group.FieldDefaultValidityDays) + return u +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. +func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert { + u.Add(group.FieldDefaultValidityDays, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1089,6 +1164,27 @@ func (u *GroupUpsertOne) ClearMonthlyLimitUsd() *GroupUpsertOne { }) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsertOne) SetDefaultValidityDays(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultValidityDays(v) + }) +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. 
+func (u *GroupUpsertOne) AddDefaultValidityDays(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddDefaultValidityDays(v) + }) +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultValidityDays() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1531,6 +1627,27 @@ func (u *GroupUpsertBulk) ClearMonthlyLimitUsd() *GroupUpsertBulk { }) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsertBulk) SetDefaultValidityDays(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultValidityDays(v) + }) +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. +func (u *GroupUpsertBulk) AddDefaultValidityDays(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddDefaultValidityDays(v) + }) +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultValidityDays() + }) +} + // Exec executes the query. 
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_query.go b/backend/ent/group_query.go index 0b86e069..93a8d8c2 100644 --- a/backend/ent/group_query.go +++ b/backend/ent/group_query.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -33,6 +34,7 @@ type GroupQuery struct { withAPIKeys *ApiKeyQuery withRedeemCodes *RedeemCodeQuery withSubscriptions *UserSubscriptionQuery + withUsageLogs *UsageLogQuery withAccounts *AccountQuery withAllowedUsers *UserQuery withAccountGroups *AccountGroupQuery @@ -139,6 +141,28 @@ func (_q *GroupQuery) QuerySubscriptions() *UserSubscriptionQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *GroupQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryAccounts chains the current query on the "accounts" edge. 
func (_q *GroupQuery) QueryAccounts() *AccountQuery { query := (&AccountClient{config: _q.config}).Query() @@ -422,6 +446,7 @@ func (_q *GroupQuery) Clone() *GroupQuery { withAPIKeys: _q.withAPIKeys.Clone(), withRedeemCodes: _q.withRedeemCodes.Clone(), withSubscriptions: _q.withSubscriptions.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withAccounts: _q.withAccounts.Clone(), withAllowedUsers: _q.withAllowedUsers.Clone(), withAccountGroups: _q.withAccountGroups.Clone(), @@ -465,6 +490,17 @@ func (_q *GroupQuery) WithSubscriptions(opts ...func(*UserSubscriptionQuery)) *G return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *GroupQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *GroupQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithAccounts tells the query-builder to eager-load the nodes that are connected to // the "accounts" edge. The optional arguments are used to configure the query builder of the edge. 
func (_q *GroupQuery) WithAccounts(opts ...func(*AccountQuery)) *GroupQuery { @@ -587,10 +623,11 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, var ( nodes = []*Group{} _spec = _q.querySpec() - loadedTypes = [7]bool{ + loadedTypes = [8]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, + _q.withUsageLogs != nil, _q.withAccounts != nil, _q.withAllowedUsers != nil, _q.withAccountGroups != nil, @@ -636,6 +673,13 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *Group) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *Group, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withAccounts; query != nil { if err := _q.loadAccounts(ctx, query, nodes, func(n *Group) { n.Edges.Accounts = []*Account{} }, @@ -763,6 +807,39 @@ func (_q *GroupQuery) loadSubscriptions(ctx context.Context, query *UserSubscrip } return nil } +func (_q *GroupQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Group, init func(*Group), assign func(*Group, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldGroupID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(group.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.GroupID + if fk == nil { + return fmt.Errorf(`foreign-key "group_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced 
foreign-key "group_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *GroupQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Group, init func(*Group), assign func(*Group, *Account)) error { edgeIDs := make([]driver.Value, len(nodes)) byID := make(map[int64]*Group) diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 0ed1e3fd..1825a892 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -251,6 +252,27 @@ func (_u *GroupUpdate) ClearMonthlyLimitUsd() *GroupUpdate { return _u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_u *GroupUpdate) SetDefaultValidityDays(v int) *GroupUpdate { + _u.mutation.ResetDefaultValidityDays() + _u.mutation.SetDefaultValidityDays(v) + return _u +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableDefaultValidityDays(v *int) *GroupUpdate { + if v != nil { + _u.SetDefaultValidityDays(*v) + } + return _u +} + +// AddDefaultValidityDays adds value to the "default_validity_days" field. +func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate { + _u.mutation.AddDefaultValidityDays(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -296,6 +318,21 @@ func (_u *GroupUpdate) AddSubscriptions(v ...*UserSubscription) *GroupUpdate { return _u.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. 
+func (_u *GroupUpdate) AddUsageLogIDs(ids ...int64) *GroupUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdate) AddUsageLogs(v ...*UsageLog) *GroupUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_u *GroupUpdate) AddAccountIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAccountIDs(ids...) @@ -394,6 +431,27 @@ func (_u *GroupUpdate) RemoveSubscriptions(v ...*UserSubscription) *GroupUpdate return _u.RemoveSubscriptionIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdate) ClearUsageLogs() *GroupUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *GroupUpdate) RemoveUsageLogIDs(ids ...int64) *GroupUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *GroupUpdate) RemoveUsageLogs(v ...*UsageLog) *GroupUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // ClearAccounts clears all "accounts" edges to the Account entity. 
func (_u *GroupUpdate) ClearAccounts() *GroupUpdate { _u.mutation.ClearAccounts() @@ -578,6 +636,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.MonthlyLimitUsdCleared() { _spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64) } + if value, ok := _u.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { + _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -713,6 +777,51 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, 
k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _u.mutation.AccountsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, @@ -1065,6 +1174,27 @@ func (_u *GroupUpdateOne) ClearMonthlyLimitUsd() *GroupUpdateOne { return _u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_u *GroupUpdateOne) SetDefaultValidityDays(v int) *GroupUpdateOne { + _u.mutation.ResetDefaultValidityDays() + _u.mutation.SetDefaultValidityDays(v) + return _u +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableDefaultValidityDays(v *int) *GroupUpdateOne { + if v != nil { + _u.SetDefaultValidityDays(*v) + } + return _u +} + +// AddDefaultValidityDays adds value to the "default_validity_days" field. +func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne { + _u.mutation.AddDefaultValidityDays(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1110,6 +1240,21 @@ func (_u *GroupUpdateOne) AddSubscriptions(v ...*UserSubscription) *GroupUpdateO return _u.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *GroupUpdateOne) AddUsageLogIDs(ids ...int64) *GroupUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdateOne) AddUsageLogs(v ...*UsageLog) *GroupUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_u *GroupUpdateOne) AddAccountIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAccountIDs(ids...) 
@@ -1208,6 +1353,27 @@ func (_u *GroupUpdateOne) RemoveSubscriptions(v ...*UserSubscription) *GroupUpda return _u.RemoveSubscriptionIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdateOne) ClearUsageLogs() *GroupUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *GroupUpdateOne) RemoveUsageLogIDs(ids ...int64) *GroupUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *GroupUpdateOne) RemoveUsageLogs(v ...*UsageLog) *GroupUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // ClearAccounts clears all "accounts" edges to the Account entity. func (_u *GroupUpdateOne) ClearAccounts() *GroupUpdateOne { _u.mutation.ClearAccounts() @@ -1422,6 +1588,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.MonthlyLimitUsdCleared() { _spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64) } + if value, ok := _u.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { + _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1557,6 +1729,51 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + 
_spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _u.mutation.AccountsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 46933bb0..33955cbb 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m) } +// The UsageLogFunc type is an adapter to allow the use of ordinary +// function as UsageLog mutator. +type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f UsageLogFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.UsageLogMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.UsageLogMutation", m) +} + // The UserFunc type is an adapter to allow the use of ordinary // function as User mutator. type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index ab5f5554..9815f477 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -266,6 +267,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) } +// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier. +type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f UsageLogFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.UsageLogQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageLogQuery", q) +} + +// The TraverseUsageLog type is an adapter to allow the use of ordinary function as Traverser. +type TraverseUsageLog func(context.Context, *ent.UsageLogQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseUsageLog) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseUsageLog) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.UsageLogQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. 
expect *ent.UsageLogQuery", q) +} + // The UserFunc type is an adapter to allow the use of ordinary function as a Querier. type UserFunc func(context.Context, *ent.UserQuery) (ent.Value, error) @@ -364,6 +392,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil + case *ent.UsageLogQuery: + return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil case *ent.UserQuery: return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil case *ent.UserAllowedGroupQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 45408760..848ac74c 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -20,7 +20,6 @@ var ( {Name: "type", Type: field.TypeString, Size: 20}, {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, - {Name: "proxy_id", Type: field.TypeInt64, Nullable: true}, {Name: "concurrency", Type: field.TypeInt, Default: 3}, {Name: "priority", Type: field.TypeInt, Default: 50}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -33,12 +32,21 @@ var ( {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, + {Name: "proxy_id", Type: field.TypeInt64, Nullable: true}, } // AccountsTable holds the schema information for the "accounts" table. 
AccountsTable = &schema.Table{ Name: "accounts", Columns: AccountsColumns, PrimaryKey: []*schema.Column{AccountsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "accounts_proxies_proxy", + Columns: []*schema.Column{AccountsColumns[21]}, + RefColumns: []*schema.Column{ProxiesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, Indexes: []*schema.Index{ { Name: "account_platform", @@ -53,42 +61,42 @@ var ( { Name: "account_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[12]}, + Columns: []*schema.Column{AccountsColumns[11]}, }, { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[9]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[10]}, }, { Name: "account_last_used_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[14]}, + Columns: []*schema.Column{AccountsColumns[13]}, }, { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[14]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[16]}, + Columns: []*schema.Column{AccountsColumns[15]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[17]}, + Columns: []*schema.Column{AccountsColumns[16]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[17]}, }, { Name: "account_deleted_at", @@ -100,7 +108,7 @@ var ( // AccountGroupsColumns holds the columns for the "account_groups" table. 
AccountGroupsColumns = []*schema.Column{ {Name: "priority", Type: field.TypeInt, Default: 50}, - {Name: "created_at", Type: field.TypeTime}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "account_id", Type: field.TypeInt64}, {Name: "group_id", Type: field.TypeInt64}, } @@ -168,11 +176,6 @@ var ( }, }, Indexes: []*schema.Index{ - { - Name: "apikey_key", - Unique: true, - Columns: []*schema.Column{APIKeysColumns[4]}, - }, { Name: "apikey_user_id", Unique: false, @@ -211,6 +214,7 @@ var ( {Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "default_validity_days", Type: field.TypeInt, Default: 30}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -218,11 +222,6 @@ var ( Columns: GroupsColumns, PrimaryKey: []*schema.Column{GroupsColumns[0]}, Indexes: []*schema.Index{ - { - Name: "group_name", - Unique: true, - Columns: []*schema.Column{GroupsColumns[4]}, - }, { Name: "group_status", Unique: false, @@ -316,11 +315,6 @@ var ( }, }, Indexes: []*schema.Index{ - { - Name: "redeemcode_code", - Unique: true, - Columns: []*schema.Column{RedeemCodesColumns[1]}, - }, { Name: "redeemcode_status", Unique: false, @@ -350,11 +344,123 @@ var ( Name: "settings", Columns: SettingsColumns, PrimaryKey: []*schema.Column{SettingsColumns[0]}, + } + // UsageLogsColumns holds the columns for the "usage_logs" table. 
+ UsageLogsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "request_id", Type: field.TypeString, Size: 64}, + {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "input_tokens", Type: field.TypeInt, Default: 0}, + {Name: "output_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_read_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_5m_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_1h_tokens", Type: field.TypeInt, Default: 0}, + {Name: "input_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "output_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "cache_creation_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "cache_read_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, + {Name: "billing_type", Type: field.TypeInt8, Default: 0}, + {Name: "stream", Type: field.TypeBool, Default: false}, + {Name: "duration_ms", Type: field.TypeInt, Nullable: true}, + {Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "api_key_id", Type: field.TypeInt64}, + {Name: "group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "user_id", 
Type: field.TypeInt64}, + {Name: "subscription_id", Type: field.TypeInt64, Nullable: true}, + } + // UsageLogsTable holds the schema information for the "usage_logs" table. + UsageLogsTable = &schema.Table{ + Name: "usage_logs", + Columns: UsageLogsColumns, + PrimaryKey: []*schema.Column{UsageLogsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "usage_logs_accounts_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[21]}, + RefColumns: []*schema.Column{AccountsColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_api_keys_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[22]}, + RefColumns: []*schema.Column{APIKeysColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_groups_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[23]}, + RefColumns: []*schema.Column{GroupsColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "usage_logs_users_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[24]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_user_subscriptions_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[25]}, + RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, + OnDelete: schema.SetNull, + }, + }, Indexes: []*schema.Index{ { - Name: "setting_key", - Unique: true, - Columns: []*schema.Column{SettingsColumns[1]}, + Name: "usagelog_user_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[24]}, + }, + { + Name: "usagelog_api_key_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[22]}, + }, + { + Name: "usagelog_account_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[21]}, + }, + { + Name: "usagelog_group_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[23]}, + }, + { + Name: "usagelog_subscription_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[25]}, + }, + { + Name: "usagelog_created_at", + Unique: false, + Columns: 
[]*schema.Column{UsageLogsColumns[20]}, + }, + { + Name: "usagelog_model", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[2]}, + }, + { + Name: "usagelog_request_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[1]}, + }, + { + Name: "usagelog_user_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[20]}, + }, + { + Name: "usagelog_api_key_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[22], UsageLogsColumns[20]}, }, }, } @@ -380,11 +486,6 @@ var ( Columns: UsersColumns, PrimaryKey: []*schema.Column{UsersColumns[0]}, Indexes: []*schema.Index{ - { - Name: "user_email", - Unique: true, - Columns: []*schema.Column{UsersColumns[4]}, - }, { Name: "user_status", Unique: false, @@ -399,7 +500,7 @@ var ( } // UserAllowedGroupsColumns holds the columns for the "user_allowed_groups" table. UserAllowedGroupsColumns = []*schema.Column{ - {Name: "created_at", Type: field.TypeTime}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "user_id", Type: field.TypeInt64}, {Name: "group_id", Type: field.TypeInt64}, } @@ -435,6 +536,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "starts_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -458,19 +560,19 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "user_subscriptions_groups_subscriptions", - Columns: 
[]*schema.Column{UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[15]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "user_subscriptions_users_subscriptions", - Columns: []*schema.Column{UserSubscriptionsColumns[15]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "user_subscriptions_users_assigned_subscriptions", - Columns: []*schema.Column{UserSubscriptionsColumns[16]}, + Columns: []*schema.Column{UserSubscriptionsColumns[17]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -479,32 +581,37 @@ var ( { Name: "usersubscription_user_id", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[15]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16]}, }, { Name: "usersubscription_group_id", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[15]}, }, { Name: "usersubscription_status", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[5]}, + Columns: []*schema.Column{UserSubscriptionsColumns[6]}, }, { Name: "usersubscription_expires_at", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[4]}, + Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, { Name: "usersubscription_assigned_by", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[16]}, + Columns: []*schema.Column{UserSubscriptionsColumns[17]}, }, { Name: "usersubscription_user_id_group_id", Unique: true, - Columns: []*schema.Column{UserSubscriptionsColumns[15], UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[15]}, + }, + { + Name: "usersubscription_deleted_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[3]}, }, }, } @@ -517,6 +624,7 @@ var ( ProxiesTable, 
RedeemCodesTable, SettingsTable, + UsageLogsTable, UsersTable, UserAllowedGroupsTable, UserSubscriptionsTable, @@ -524,6 +632,7 @@ var ( ) func init() { + AccountsTable.ForeignKeys[0].RefTable = ProxiesTable AccountsTable.Annotation = &entsql.Annotation{ Table: "accounts", } @@ -551,6 +660,14 @@ func init() { SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } + UsageLogsTable.ForeignKeys[0].RefTable = AccountsTable + UsageLogsTable.ForeignKeys[1].RefTable = APIKeysTable + UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable + UsageLogsTable.ForeignKeys[3].RefTable = UsersTable + UsageLogsTable.ForeignKeys[4].RefTable = UserSubscriptionsTable + UsageLogsTable.Annotation = &entsql.Annotation{ + Table: "usage_logs", + } UsersTable.Annotation = &entsql.Annotation{ Table: "users", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 45a6f5a7..9e4359ab 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -40,6 +41,7 @@ const ( TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" TypeSetting = "Setting" + TypeUsageLog = "UsageLog" TypeUser = "User" TypeUserAllowedGroup = "UserAllowedGroup" TypeUserSubscription = "UserSubscription" @@ -59,8 +61,6 @@ type AccountMutation struct { _type *string credentials *map[string]interface{} extra *map[string]interface{} - proxy_id *int64 - addproxy_id *int64 concurrency *int addconcurrency *int priority *int @@ -79,6 +79,11 @@ type AccountMutation struct { groups map[int64]struct{} removedgroups map[int64]struct{} clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + 
clearedusage_logs bool done bool oldValue func(context.Context) (*Account, error) predicates []predicate.Account @@ -485,13 +490,12 @@ func (m *AccountMutation) ResetExtra() { // SetProxyID sets the "proxy_id" field. func (m *AccountMutation) SetProxyID(i int64) { - m.proxy_id = &i - m.addproxy_id = nil + m.proxy = &i } // ProxyID returns the value of the "proxy_id" field in the mutation. func (m *AccountMutation) ProxyID() (r int64, exists bool) { - v := m.proxy_id + v := m.proxy if v == nil { return } @@ -515,28 +519,9 @@ func (m *AccountMutation) OldProxyID(ctx context.Context) (v *int64, err error) return oldValue.ProxyID, nil } -// AddProxyID adds i to the "proxy_id" field. -func (m *AccountMutation) AddProxyID(i int64) { - if m.addproxy_id != nil { - *m.addproxy_id += i - } else { - m.addproxy_id = &i - } -} - -// AddedProxyID returns the value that was added to the "proxy_id" field in this mutation. -func (m *AccountMutation) AddedProxyID() (r int64, exists bool) { - v := m.addproxy_id - if v == nil { - return - } - return *v, true -} - // ClearProxyID clears the value of the "proxy_id" field. func (m *AccountMutation) ClearProxyID() { - m.proxy_id = nil - m.addproxy_id = nil + m.proxy = nil m.clearedFields[account.FieldProxyID] = struct{}{} } @@ -548,8 +533,7 @@ func (m *AccountMutation) ProxyIDCleared() bool { // ResetProxyID resets all changes to the "proxy_id" field. func (m *AccountMutation) ResetProxyID() { - m.proxy_id = nil - m.addproxy_id = nil + m.proxy = nil delete(m.clearedFields, account.FieldProxyID) } @@ -1183,6 +1167,87 @@ func (m *AccountMutation) ResetGroups() { m.removedgroups = nil } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (m *AccountMutation) ClearProxy() { + m.clearedproxy = true + m.clearedFields[account.FieldProxyID] = struct{}{} +} + +// ProxyCleared reports if the "proxy" edge to the Proxy entity was cleared. 
+func (m *AccountMutation) ProxyCleared() bool { + return m.ProxyIDCleared() || m.clearedproxy +} + +// ProxyIDs returns the "proxy" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ProxyID instead. It exists only for internal usage by the builders. +func (m *AccountMutation) ProxyIDs() (ids []int64) { + if id := m.proxy; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetProxy resets all changes to the "proxy" edge. +func (m *AccountMutation) ResetProxy() { + m.proxy = nil + m.clearedproxy = false +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *AccountMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *AccountMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *AccountMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *AccountMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *AccountMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. 
+func (m *AccountMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *AccountMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the AccountMutation builder. func (m *AccountMutation) Where(ps ...predicate.Account) { m.predicates = append(m.predicates, ps...) @@ -1242,7 +1307,7 @@ func (m *AccountMutation) Fields() []string { if m.extra != nil { fields = append(fields, account.FieldExtra) } - if m.proxy_id != nil { + if m.proxy != nil { fields = append(fields, account.FieldProxyID) } if m.concurrency != nil { @@ -1546,9 +1611,6 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *AccountMutation) AddedFields() []string { var fields []string - if m.addproxy_id != nil { - fields = append(fields, account.FieldProxyID) - } if m.addconcurrency != nil { fields = append(fields, account.FieldConcurrency) } @@ -1563,8 +1625,6 @@ func (m *AccountMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { switch name { - case account.FieldProxyID: - return m.AddedProxyID() case account.FieldConcurrency: return m.AddedConcurrency() case account.FieldPriority: @@ -1578,13 +1638,6 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { // type. 
func (m *AccountMutation) AddField(name string, value ent.Value) error { switch name { - case account.FieldProxyID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddProxyID(v) - return nil case account.FieldConcurrency: v, ok := value.(int) if !ok { @@ -1758,10 +1811,16 @@ func (m *AccountMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *AccountMutation) AddedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.groups != nil { edges = append(edges, account.EdgeGroups) } + if m.proxy != nil { + edges = append(edges, account.EdgeProxy) + } + if m.usage_logs != nil { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1775,16 +1834,29 @@ func (m *AccountMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case account.EdgeProxy: + if id := m.proxy; id != nil { + return []ent.Value{*id} + } + case account.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *AccountMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.removedgroups != nil { edges = append(edges, account.EdgeGroups) } + if m.removedusage_logs != nil { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1798,16 +1870,28 @@ func (m *AccountMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case account.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. 
func (m *AccountMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.clearedgroups { edges = append(edges, account.EdgeGroups) } + if m.clearedproxy { + edges = append(edges, account.EdgeProxy) + } + if m.clearedusage_logs { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1817,6 +1901,10 @@ func (m *AccountMutation) EdgeCleared(name string) bool { switch name { case account.EdgeGroups: return m.clearedgroups + case account.EdgeProxy: + return m.clearedproxy + case account.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -1825,6 +1913,9 @@ func (m *AccountMutation) EdgeCleared(name string) bool { // if that edge is not defined in the schema. func (m *AccountMutation) ClearEdge(name string) error { switch name { + case account.EdgeProxy: + m.ClearProxy() + return nil } return fmt.Errorf("unknown Account unique edge %s", name) } @@ -1836,6 +1927,12 @@ func (m *AccountMutation) ResetEdge(name string) error { case account.EdgeGroups: m.ResetGroups() return nil + case account.EdgeProxy: + m.ResetProxy() + return nil + case account.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown Account edge %s", name) } @@ -2328,23 +2425,26 @@ func (m *AccountGroupMutation) ResetEdge(name string) error { // ApiKeyMutation represents an operation that mutates the ApiKey nodes in the graph. 
type ApiKeyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - key *string - name *string - status *string - clearedFields map[string]struct{} - user *int64 - cleareduser bool - group *int64 - clearedgroup bool - done bool - oldValue func(context.Context) (*ApiKey, error) - predicates []predicate.ApiKey + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + key *string + name *string + status *string + clearedFields map[string]struct{} + user *int64 + cleareduser bool + group *int64 + clearedgroup bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*ApiKey, error) + predicates []predicate.ApiKey } var _ ent.Mutation = (*ApiKeyMutation)(nil) @@ -2813,6 +2913,60 @@ func (m *ApiKeyMutation) ResetGroup() { m.clearedgroup = false } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *ApiKeyMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *ApiKeyMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *ApiKeyMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *ApiKeyMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. 
+func (m *ApiKeyMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *ApiKeyMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *ApiKeyMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the ApiKeyMutation builder. func (m *ApiKeyMutation) Where(ps ...predicate.ApiKey) { m.predicates = append(m.predicates, ps...) @@ -3083,13 +3237,16 @@ func (m *ApiKeyMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ApiKeyMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.user != nil { edges = append(edges, apikey.EdgeUser) } if m.group != nil { edges = append(edges, apikey.EdgeGroup) } + if m.usage_logs != nil { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } @@ -3105,31 +3262,51 @@ func (m *ApiKeyMutation) AddedIDs(name string) []ent.Value { if id := m.group; id != nil { return []ent.Value{*id} } + case apikey.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ApiKeyMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) + if m.removedusage_logs != nil { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. 
func (m *ApiKeyMutation) RemovedIDs(name string) []ent.Value { + switch name { + case apikey.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ApiKeyMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.cleareduser { edges = append(edges, apikey.EdgeUser) } if m.clearedgroup { edges = append(edges, apikey.EdgeGroup) } + if m.clearedusage_logs { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } @@ -3141,6 +3318,8 @@ func (m *ApiKeyMutation) EdgeCleared(name string) bool { return m.cleareduser case apikey.EdgeGroup: return m.clearedgroup + case apikey.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -3169,6 +3348,9 @@ func (m *ApiKeyMutation) ResetEdge(name string) error { case apikey.EdgeGroup: m.ResetGroup() return nil + case apikey.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown ApiKey edge %s", name) } @@ -3176,45 +3358,50 @@ func (m *ApiKeyMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. 
type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs 
map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -3931,6 +4118,62 @@ func (m *GroupMutation) ResetMonthlyLimitUsd() { delete(m.clearedFields, group.FieldMonthlyLimitUsd) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (m *GroupMutation) SetDefaultValidityDays(i int) { + m.default_validity_days = &i + m.adddefault_validity_days = nil +} + +// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. +func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { + v := m.default_validity_days + if v == nil { + return + } + return *v, true +} + +// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) + } + return oldValue.DefaultValidityDays, nil +} + +// AddDefaultValidityDays adds i to the "default_validity_days" field. 
+func (m *GroupMutation) AddDefaultValidityDays(i int) { + if m.adddefault_validity_days != nil { + *m.adddefault_validity_days += i + } else { + m.adddefault_validity_days = &i + } +} + +// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. +func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { + v := m.adddefault_validity_days + if v == nil { + return + } + return *v, true +} + +// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. +func (m *GroupMutation) ResetDefaultValidityDays() { + m.default_validity_days = nil + m.adddefault_validity_days = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -4093,6 +4336,60 @@ func (m *GroupMutation) ResetSubscriptions() { m.removedsubscriptions = nil } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *GroupMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. 
+func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *GroupMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *GroupMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // AddAccountIDs adds the "accounts" edge to the Account entity by ids. func (m *GroupMutation) AddAccountIDs(ids ...int64) { if m.accounts == nil { @@ -4235,7 +4532,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -4275,6 +4572,9 @@ func (m *GroupMutation) Fields() []string { if m.monthly_limit_usd != nil { fields = append(fields, group.FieldMonthlyLimitUsd) } + if m.default_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } return fields } @@ -4309,6 +4609,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.WeeklyLimitUsd() case group.FieldMonthlyLimitUsd: return m.MonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.DefaultValidityDays() } return nil, false } @@ -4344,6 +4646,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldWeeklyLimitUsd(ctx) case group.FieldMonthlyLimitUsd: return m.OldMonthlyLimitUsd(ctx) + case group.FieldDefaultValidityDays: + return m.OldDefaultValidityDays(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -4444,6 +4748,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } 
m.SetMonthlyLimitUsd(v) return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultValidityDays(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -4464,6 +4775,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addmonthly_limit_usd != nil { fields = append(fields, group.FieldMonthlyLimitUsd) } + if m.adddefault_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } return fields } @@ -4480,6 +4794,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedWeeklyLimitUsd() case group.FieldMonthlyLimitUsd: return m.AddedMonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.AddedDefaultValidityDays() } return nil, false } @@ -4517,6 +4833,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddMonthlyLimitUsd(v) return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDefaultValidityDays(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -4616,13 +4939,16 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMonthlyLimitUsd: m.ResetMonthlyLimitUsd() return nil + case group.FieldDefaultValidityDays: + m.ResetDefaultValidityDays() + return nil } return fmt.Errorf("unknown Group field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. 
func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.api_keys != nil { edges = append(edges, group.EdgeAPIKeys) } @@ -4632,6 +4958,9 @@ func (m *GroupMutation) AddedEdges() []string { if m.subscriptions != nil { edges = append(edges, group.EdgeSubscriptions) } + if m.usage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } if m.accounts != nil { edges = append(edges, group.EdgeAccounts) } @@ -4663,6 +4992,12 @@ func (m *GroupMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids case group.EdgeAccounts: ids := make([]ent.Value, 0, len(m.accounts)) for id := range m.accounts { @@ -4681,7 +5016,7 @@ func (m *GroupMutation) AddedIDs(name string) []ent.Value { // RemovedEdges returns all edge names that were removed in this mutation. func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.removedapi_keys != nil { edges = append(edges, group.EdgeAPIKeys) } @@ -4691,6 +5026,9 @@ func (m *GroupMutation) RemovedEdges() []string { if m.removedsubscriptions != nil { edges = append(edges, group.EdgeSubscriptions) } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } if m.removedaccounts != nil { edges = append(edges, group.EdgeAccounts) } @@ -4722,6 +5060,12 @@ func (m *GroupMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids case group.EdgeAccounts: ids := make([]ent.Value, 0, len(m.removedaccounts)) for id := range m.removedaccounts { @@ -4740,7 +5084,7 @@ func (m *GroupMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns 
all edge names that were cleared in this mutation. func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.clearedapi_keys { edges = append(edges, group.EdgeAPIKeys) } @@ -4750,6 +5094,9 @@ func (m *GroupMutation) ClearedEdges() []string { if m.clearedsubscriptions { edges = append(edges, group.EdgeSubscriptions) } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } if m.clearedaccounts { edges = append(edges, group.EdgeAccounts) } @@ -4769,6 +5116,8 @@ func (m *GroupMutation) EdgeCleared(name string) bool { return m.clearedredeem_codes case group.EdgeSubscriptions: return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs case group.EdgeAccounts: return m.clearedaccounts case group.EdgeAllowedUsers: @@ -4798,6 +5147,9 @@ func (m *GroupMutation) ResetEdge(name string) error { case group.EdgeSubscriptions: m.ResetSubscriptions() return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil case group.EdgeAccounts: m.ResetAccounts() return nil @@ -4811,24 +5163,27 @@ func (m *GroupMutation) ResetEdge(name string) error { // ProxyMutation represents an operation that mutates the Proxy nodes in the graph. 
type ProxyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - protocol *string - host *string - port *int - addport *int - username *string - password *string - status *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*Proxy, error) - predicates []predicate.Proxy + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + protocol *string + host *string + port *int + addport *int + username *string + password *string + status *string + clearedFields map[string]struct{} + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + done bool + oldValue func(context.Context) (*Proxy, error) + predicates []predicate.Proxy } var _ ent.Mutation = (*ProxyMutation)(nil) @@ -5348,6 +5703,60 @@ func (m *ProxyMutation) ResetStatus() { m.status = nil } +// AddAccountIDs adds the "accounts" edge to the Account entity by ids. +func (m *ProxyMutation) AddAccountIDs(ids ...int64) { + if m.accounts == nil { + m.accounts = make(map[int64]struct{}) + } + for i := range ids { + m.accounts[ids[i]] = struct{}{} + } +} + +// ClearAccounts clears the "accounts" edge to the Account entity. +func (m *ProxyMutation) ClearAccounts() { + m.clearedaccounts = true +} + +// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. +func (m *ProxyMutation) AccountsCleared() bool { + return m.clearedaccounts +} + +// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. +func (m *ProxyMutation) RemoveAccountIDs(ids ...int64) { + if m.removedaccounts == nil { + m.removedaccounts = make(map[int64]struct{}) + } + for i := range ids { + delete(m.accounts, ids[i]) + m.removedaccounts[ids[i]] = struct{}{} + } +} + +// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. 
+func (m *ProxyMutation) RemovedAccountsIDs() (ids []int64) { + for id := range m.removedaccounts { + ids = append(ids, id) + } + return +} + +// AccountsIDs returns the "accounts" edge IDs in the mutation. +func (m *ProxyMutation) AccountsIDs() (ids []int64) { + for id := range m.accounts { + ids = append(ids, id) + } + return +} + +// ResetAccounts resets all changes to the "accounts" edge. +func (m *ProxyMutation) ResetAccounts() { + m.accounts = nil + m.clearedaccounts = false + m.removedaccounts = nil +} + // Where appends a list predicates to the ProxyMutation builder. func (m *ProxyMutation) Where(ps ...predicate.Proxy) { m.predicates = append(m.predicates, ps...) @@ -5670,49 +6079,85 @@ func (m *ProxyMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ProxyMutation) AddedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.accounts != nil { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. func (m *ProxyMutation) AddedIDs(name string) []ent.Value { + switch name { + case proxy.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ProxyMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.removedaccounts != nil { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. 
func (m *ProxyMutation) RemovedIDs(name string) []ent.Value { + switch name { + case proxy.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ProxyMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.clearedaccounts { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. func (m *ProxyMutation) EdgeCleared(name string) bool { + switch name { + case proxy.EdgeAccounts: + return m.clearedaccounts + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. func (m *ProxyMutation) ClearEdge(name string) error { + switch name { + } return fmt.Errorf("unknown Proxy unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. func (m *ProxyMutation) ResetEdge(name string) error { + switch name { + case proxy.EdgeAccounts: + m.ResetAccounts() + return nil + } return fmt.Errorf("unknown Proxy edge %s", name) } @@ -7223,6 +7668,2478 @@ func (m *SettingMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Setting edge %s", name) } +// UsageLogMutation represents an operation that mutates the UsageLog nodes in the graph. 
+type UsageLogMutation struct { + config + op Op + typ string + id *int64 + request_id *string + model *string + input_tokens *int + addinput_tokens *int + output_tokens *int + addoutput_tokens *int + cache_creation_tokens *int + addcache_creation_tokens *int + cache_read_tokens *int + addcache_read_tokens *int + cache_creation_5m_tokens *int + addcache_creation_5m_tokens *int + cache_creation_1h_tokens *int + addcache_creation_1h_tokens *int + input_cost *float64 + addinput_cost *float64 + output_cost *float64 + addoutput_cost *float64 + cache_creation_cost *float64 + addcache_creation_cost *float64 + cache_read_cost *float64 + addcache_read_cost *float64 + total_cost *float64 + addtotal_cost *float64 + actual_cost *float64 + addactual_cost *float64 + rate_multiplier *float64 + addrate_multiplier *float64 + billing_type *int8 + addbilling_type *int8 + stream *bool + duration_ms *int + addduration_ms *int + first_token_ms *int + addfirst_token_ms *int + created_at *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + api_key *int64 + clearedapi_key bool + account *int64 + clearedaccount bool + group *int64 + clearedgroup bool + subscription *int64 + clearedsubscription bool + done bool + oldValue func(context.Context) (*UsageLog, error) + predicates []predicate.UsageLog +} + +var _ ent.Mutation = (*UsageLogMutation)(nil) + +// usagelogOption allows management of the mutation configuration using functional options. +type usagelogOption func(*UsageLogMutation) + +// newUsageLogMutation creates new mutation for the UsageLog entity. +func newUsageLogMutation(c config, op Op, opts ...usagelogOption) *UsageLogMutation { + m := &UsageLogMutation{ + config: c, + op: op, + typ: TypeUsageLog, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUsageLogID sets the ID field of the mutation. 
+func withUsageLogID(id int64) usagelogOption { + return func(m *UsageLogMutation) { + var ( + err error + once sync.Once + value *UsageLog + ) + m.oldValue = func(ctx context.Context) (*UsageLog, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().UsageLog.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUsageLog sets the old UsageLog of the mutation. +func withUsageLog(node *UsageLog) usagelogOption { + return func(m *UsageLogMutation) { + m.oldValue = func(context.Context) (*UsageLog, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UsageLogMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UsageLogMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UsageLogMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *UsageLogMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().UsageLog.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetUserID sets the "user_id" field. +func (m *UsageLogMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *UsageLogMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *UsageLogMutation) ResetUserID() { + m.user = nil +} + +// SetAPIKeyID sets the "api_key_id" field. +func (m *UsageLogMutation) SetAPIKeyID(i int64) { + m.api_key = &i +} + +// APIKeyID returns the value of the "api_key_id" field in the mutation. +func (m *UsageLogMutation) APIKeyID() (r int64, exists bool) { + v := m.api_key + if v == nil { + return + } + return *v, true +} + +// OldAPIKeyID returns the old "api_key_id" field's value of the UsageLog entity. 
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldAPIKeyID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAPIKeyID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAPIKeyID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAPIKeyID: %w", err) + } + return oldValue.APIKeyID, nil +} + +// ResetAPIKeyID resets all changes to the "api_key_id" field. +func (m *UsageLogMutation) ResetAPIKeyID() { + m.api_key = nil +} + +// SetAccountID sets the "account_id" field. +func (m *UsageLogMutation) SetAccountID(i int64) { + m.account = &i +} + +// AccountID returns the value of the "account_id" field in the mutation. +func (m *UsageLogMutation) AccountID() (r int64, exists bool) { + v := m.account + if v == nil { + return + } + return *v, true +} + +// OldAccountID returns the old "account_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldAccountID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccountID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccountID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccountID: %w", err) + } + return oldValue.AccountID, nil +} + +// ResetAccountID resets all changes to the "account_id" field. 
+func (m *UsageLogMutation) ResetAccountID() { + m.account = nil +} + +// SetRequestID sets the "request_id" field. +func (m *UsageLogMutation) SetRequestID(s string) { + m.request_id = &s +} + +// RequestID returns the value of the "request_id" field in the mutation. +func (m *UsageLogMutation) RequestID() (r string, exists bool) { + v := m.request_id + if v == nil { + return + } + return *v, true +} + +// OldRequestID returns the old "request_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRequestID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestID: %w", err) + } + return oldValue.RequestID, nil +} + +// ResetRequestID resets all changes to the "request_id" field. +func (m *UsageLogMutation) ResetRequestID() { + m.request_id = nil +} + +// SetModel sets the "model" field. +func (m *UsageLogMutation) SetModel(s string) { + m.model = &s +} + +// Model returns the value of the "model" field in the mutation. +func (m *UsageLogMutation) Model() (r string, exists bool) { + v := m.model + if v == nil { + return + } + return *v, true +} + +// OldModel returns the old "model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UsageLogMutation) OldModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModel: %w", err) + } + return oldValue.Model, nil +} + +// ResetModel resets all changes to the "model" field. +func (m *UsageLogMutation) ResetModel() { + m.model = nil +} + +// SetGroupID sets the "group_id" field. +func (m *UsageLogMutation) SetGroupID(i int64) { + m.group = &i +} + +// GroupID returns the value of the "group_id" field in the mutation. +func (m *UsageLogMutation) GroupID() (r int64, exists bool) { + v := m.group + if v == nil { + return + } + return *v, true +} + +// OldGroupID returns the old "group_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldGroupID: %w", err) + } + return oldValue.GroupID, nil +} + +// ClearGroupID clears the value of the "group_id" field. +func (m *UsageLogMutation) ClearGroupID() { + m.group = nil + m.clearedFields[usagelog.FieldGroupID] = struct{}{} +} + +// GroupIDCleared returns if the "group_id" field was cleared in this mutation. 
+func (m *UsageLogMutation) GroupIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldGroupID] + return ok +} + +// ResetGroupID resets all changes to the "group_id" field. +func (m *UsageLogMutation) ResetGroupID() { + m.group = nil + delete(m.clearedFields, usagelog.FieldGroupID) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (m *UsageLogMutation) SetSubscriptionID(i int64) { + m.subscription = &i +} + +// SubscriptionID returns the value of the "subscription_id" field in the mutation. +func (m *UsageLogMutation) SubscriptionID() (r int64, exists bool) { + v := m.subscription + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionID returns the old "subscription_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldSubscriptionID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionID: %w", err) + } + return oldValue.SubscriptionID, nil +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (m *UsageLogMutation) ClearSubscriptionID() { + m.subscription = nil + m.clearedFields[usagelog.FieldSubscriptionID] = struct{}{} +} + +// SubscriptionIDCleared returns if the "subscription_id" field was cleared in this mutation. +func (m *UsageLogMutation) SubscriptionIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldSubscriptionID] + return ok +} + +// ResetSubscriptionID resets all changes to the "subscription_id" field. 
+func (m *UsageLogMutation) ResetSubscriptionID() { + m.subscription = nil + delete(m.clearedFields, usagelog.FieldSubscriptionID) +} + +// SetInputTokens sets the "input_tokens" field. +func (m *UsageLogMutation) SetInputTokens(i int) { + m.input_tokens = &i + m.addinput_tokens = nil +} + +// InputTokens returns the value of the "input_tokens" field in the mutation. +func (m *UsageLogMutation) InputTokens() (r int, exists bool) { + v := m.input_tokens + if v == nil { + return + } + return *v, true +} + +// OldInputTokens returns the old "input_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldInputTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldInputTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldInputTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldInputTokens: %w", err) + } + return oldValue.InputTokens, nil +} + +// AddInputTokens adds i to the "input_tokens" field. +func (m *UsageLogMutation) AddInputTokens(i int) { + if m.addinput_tokens != nil { + *m.addinput_tokens += i + } else { + m.addinput_tokens = &i + } +} + +// AddedInputTokens returns the value that was added to the "input_tokens" field in this mutation. +func (m *UsageLogMutation) AddedInputTokens() (r int, exists bool) { + v := m.addinput_tokens + if v == nil { + return + } + return *v, true +} + +// ResetInputTokens resets all changes to the "input_tokens" field. +func (m *UsageLogMutation) ResetInputTokens() { + m.input_tokens = nil + m.addinput_tokens = nil +} + +// SetOutputTokens sets the "output_tokens" field. 
+func (m *UsageLogMutation) SetOutputTokens(i int) { + m.output_tokens = &i + m.addoutput_tokens = nil +} + +// OutputTokens returns the value of the "output_tokens" field in the mutation. +func (m *UsageLogMutation) OutputTokens() (r int, exists bool) { + v := m.output_tokens + if v == nil { + return + } + return *v, true +} + +// OldOutputTokens returns the old "output_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldOutputTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutputTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutputTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutputTokens: %w", err) + } + return oldValue.OutputTokens, nil +} + +// AddOutputTokens adds i to the "output_tokens" field. +func (m *UsageLogMutation) AddOutputTokens(i int) { + if m.addoutput_tokens != nil { + *m.addoutput_tokens += i + } else { + m.addoutput_tokens = &i + } +} + +// AddedOutputTokens returns the value that was added to the "output_tokens" field in this mutation. +func (m *UsageLogMutation) AddedOutputTokens() (r int, exists bool) { + v := m.addoutput_tokens + if v == nil { + return + } + return *v, true +} + +// ResetOutputTokens resets all changes to the "output_tokens" field. +func (m *UsageLogMutation) ResetOutputTokens() { + m.output_tokens = nil + m.addoutput_tokens = nil +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. 
+func (m *UsageLogMutation) SetCacheCreationTokens(i int) { + m.cache_creation_tokens = &i + m.addcache_creation_tokens = nil +} + +// CacheCreationTokens returns the value of the "cache_creation_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreationTokens() (r int, exists bool) { + v := m.cache_creation_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreationTokens returns the old "cache_creation_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreationTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreationTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreationTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreationTokens: %w", err) + } + return oldValue.CacheCreationTokens, nil +} + +// AddCacheCreationTokens adds i to the "cache_creation_tokens" field. +func (m *UsageLogMutation) AddCacheCreationTokens(i int) { + if m.addcache_creation_tokens != nil { + *m.addcache_creation_tokens += i + } else { + m.addcache_creation_tokens = &i + } +} + +// AddedCacheCreationTokens returns the value that was added to the "cache_creation_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreationTokens() (r int, exists bool) { + v := m.addcache_creation_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreationTokens resets all changes to the "cache_creation_tokens" field. 
+func (m *UsageLogMutation) ResetCacheCreationTokens() { + m.cache_creation_tokens = nil + m.addcache_creation_tokens = nil +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (m *UsageLogMutation) SetCacheReadTokens(i int) { + m.cache_read_tokens = &i + m.addcache_read_tokens = nil +} + +// CacheReadTokens returns the value of the "cache_read_tokens" field in the mutation. +func (m *UsageLogMutation) CacheReadTokens() (r int, exists bool) { + v := m.cache_read_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheReadTokens returns the old "cache_read_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheReadTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheReadTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheReadTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheReadTokens: %w", err) + } + return oldValue.CacheReadTokens, nil +} + +// AddCacheReadTokens adds i to the "cache_read_tokens" field. +func (m *UsageLogMutation) AddCacheReadTokens(i int) { + if m.addcache_read_tokens != nil { + *m.addcache_read_tokens += i + } else { + m.addcache_read_tokens = &i + } +} + +// AddedCacheReadTokens returns the value that was added to the "cache_read_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheReadTokens() (r int, exists bool) { + v := m.addcache_read_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheReadTokens resets all changes to the "cache_read_tokens" field. 
+func (m *UsageLogMutation) ResetCacheReadTokens() { + m.cache_read_tokens = nil + m.addcache_read_tokens = nil +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) SetCacheCreation5mTokens(i int) { + m.cache_creation_5m_tokens = &i + m.addcache_creation_5m_tokens = nil +} + +// CacheCreation5mTokens returns the value of the "cache_creation_5m_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreation5mTokens() (r int, exists bool) { + v := m.cache_creation_5m_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreation5mTokens returns the old "cache_creation_5m_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreation5mTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreation5mTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreation5mTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreation5mTokens: %w", err) + } + return oldValue.CacheCreation5mTokens, nil +} + +// AddCacheCreation5mTokens adds i to the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) AddCacheCreation5mTokens(i int) { + if m.addcache_creation_5m_tokens != nil { + *m.addcache_creation_5m_tokens += i + } else { + m.addcache_creation_5m_tokens = &i + } +} + +// AddedCacheCreation5mTokens returns the value that was added to the "cache_creation_5m_tokens" field in this mutation. 
+func (m *UsageLogMutation) AddedCacheCreation5mTokens() (r int, exists bool) { + v := m.addcache_creation_5m_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreation5mTokens resets all changes to the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) ResetCacheCreation5mTokens() { + m.cache_creation_5m_tokens = nil + m.addcache_creation_5m_tokens = nil +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (m *UsageLogMutation) SetCacheCreation1hTokens(i int) { + m.cache_creation_1h_tokens = &i + m.addcache_creation_1h_tokens = nil +} + +// CacheCreation1hTokens returns the value of the "cache_creation_1h_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreation1hTokens() (r int, exists bool) { + v := m.cache_creation_1h_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreation1hTokens returns the old "cache_creation_1h_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreation1hTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreation1hTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreation1hTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreation1hTokens: %w", err) + } + return oldValue.CacheCreation1hTokens, nil +} + +// AddCacheCreation1hTokens adds i to the "cache_creation_1h_tokens" field. 
+func (m *UsageLogMutation) AddCacheCreation1hTokens(i int) { + if m.addcache_creation_1h_tokens != nil { + *m.addcache_creation_1h_tokens += i + } else { + m.addcache_creation_1h_tokens = &i + } +} + +// AddedCacheCreation1hTokens returns the value that was added to the "cache_creation_1h_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreation1hTokens() (r int, exists bool) { + v := m.addcache_creation_1h_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreation1hTokens resets all changes to the "cache_creation_1h_tokens" field. +func (m *UsageLogMutation) ResetCacheCreation1hTokens() { + m.cache_creation_1h_tokens = nil + m.addcache_creation_1h_tokens = nil +} + +// SetInputCost sets the "input_cost" field. +func (m *UsageLogMutation) SetInputCost(f float64) { + m.input_cost = &f + m.addinput_cost = nil +} + +// InputCost returns the value of the "input_cost" field in the mutation. +func (m *UsageLogMutation) InputCost() (r float64, exists bool) { + v := m.input_cost + if v == nil { + return + } + return *v, true +} + +// OldInputCost returns the old "input_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldInputCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldInputCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldInputCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldInputCost: %w", err) + } + return oldValue.InputCost, nil +} + +// AddInputCost adds f to the "input_cost" field. 
+func (m *UsageLogMutation) AddInputCost(f float64) { + if m.addinput_cost != nil { + *m.addinput_cost += f + } else { + m.addinput_cost = &f + } +} + +// AddedInputCost returns the value that was added to the "input_cost" field in this mutation. +func (m *UsageLogMutation) AddedInputCost() (r float64, exists bool) { + v := m.addinput_cost + if v == nil { + return + } + return *v, true +} + +// ResetInputCost resets all changes to the "input_cost" field. +func (m *UsageLogMutation) ResetInputCost() { + m.input_cost = nil + m.addinput_cost = nil +} + +// SetOutputCost sets the "output_cost" field. +func (m *UsageLogMutation) SetOutputCost(f float64) { + m.output_cost = &f + m.addoutput_cost = nil +} + +// OutputCost returns the value of the "output_cost" field in the mutation. +func (m *UsageLogMutation) OutputCost() (r float64, exists bool) { + v := m.output_cost + if v == nil { + return + } + return *v, true +} + +// OldOutputCost returns the old "output_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldOutputCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutputCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutputCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutputCost: %w", err) + } + return oldValue.OutputCost, nil +} + +// AddOutputCost adds f to the "output_cost" field. 
+func (m *UsageLogMutation) AddOutputCost(f float64) { + if m.addoutput_cost != nil { + *m.addoutput_cost += f + } else { + m.addoutput_cost = &f + } +} + +// AddedOutputCost returns the value that was added to the "output_cost" field in this mutation. +func (m *UsageLogMutation) AddedOutputCost() (r float64, exists bool) { + v := m.addoutput_cost + if v == nil { + return + } + return *v, true +} + +// ResetOutputCost resets all changes to the "output_cost" field. +func (m *UsageLogMutation) ResetOutputCost() { + m.output_cost = nil + m.addoutput_cost = nil +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (m *UsageLogMutation) SetCacheCreationCost(f float64) { + m.cache_creation_cost = &f + m.addcache_creation_cost = nil +} + +// CacheCreationCost returns the value of the "cache_creation_cost" field in the mutation. +func (m *UsageLogMutation) CacheCreationCost() (r float64, exists bool) { + v := m.cache_creation_cost + if v == nil { + return + } + return *v, true +} + +// OldCacheCreationCost returns the old "cache_creation_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreationCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreationCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreationCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreationCost: %w", err) + } + return oldValue.CacheCreationCost, nil +} + +// AddCacheCreationCost adds f to the "cache_creation_cost" field. 
+func (m *UsageLogMutation) AddCacheCreationCost(f float64) { + if m.addcache_creation_cost != nil { + *m.addcache_creation_cost += f + } else { + m.addcache_creation_cost = &f + } +} + +// AddedCacheCreationCost returns the value that was added to the "cache_creation_cost" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreationCost() (r float64, exists bool) { + v := m.addcache_creation_cost + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreationCost resets all changes to the "cache_creation_cost" field. +func (m *UsageLogMutation) ResetCacheCreationCost() { + m.cache_creation_cost = nil + m.addcache_creation_cost = nil +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (m *UsageLogMutation) SetCacheReadCost(f float64) { + m.cache_read_cost = &f + m.addcache_read_cost = nil +} + +// CacheReadCost returns the value of the "cache_read_cost" field in the mutation. +func (m *UsageLogMutation) CacheReadCost() (r float64, exists bool) { + v := m.cache_read_cost + if v == nil { + return + } + return *v, true +} + +// OldCacheReadCost returns the old "cache_read_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheReadCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheReadCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheReadCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheReadCost: %w", err) + } + return oldValue.CacheReadCost, nil +} + +// AddCacheReadCost adds f to the "cache_read_cost" field. 
+func (m *UsageLogMutation) AddCacheReadCost(f float64) { + if m.addcache_read_cost != nil { + *m.addcache_read_cost += f + } else { + m.addcache_read_cost = &f + } +} + +// AddedCacheReadCost returns the value that was added to the "cache_read_cost" field in this mutation. +func (m *UsageLogMutation) AddedCacheReadCost() (r float64, exists bool) { + v := m.addcache_read_cost + if v == nil { + return + } + return *v, true +} + +// ResetCacheReadCost resets all changes to the "cache_read_cost" field. +func (m *UsageLogMutation) ResetCacheReadCost() { + m.cache_read_cost = nil + m.addcache_read_cost = nil +} + +// SetTotalCost sets the "total_cost" field. +func (m *UsageLogMutation) SetTotalCost(f float64) { + m.total_cost = &f + m.addtotal_cost = nil +} + +// TotalCost returns the value of the "total_cost" field in the mutation. +func (m *UsageLogMutation) TotalCost() (r float64, exists bool) { + v := m.total_cost + if v == nil { + return + } + return *v, true +} + +// OldTotalCost returns the old "total_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldTotalCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalCost: %w", err) + } + return oldValue.TotalCost, nil +} + +// AddTotalCost adds f to the "total_cost" field. 
+func (m *UsageLogMutation) AddTotalCost(f float64) { + if m.addtotal_cost != nil { + *m.addtotal_cost += f + } else { + m.addtotal_cost = &f + } +} + +// AddedTotalCost returns the value that was added to the "total_cost" field in this mutation. +func (m *UsageLogMutation) AddedTotalCost() (r float64, exists bool) { + v := m.addtotal_cost + if v == nil { + return + } + return *v, true +} + +// ResetTotalCost resets all changes to the "total_cost" field. +func (m *UsageLogMutation) ResetTotalCost() { + m.total_cost = nil + m.addtotal_cost = nil +} + +// SetActualCost sets the "actual_cost" field. +func (m *UsageLogMutation) SetActualCost(f float64) { + m.actual_cost = &f + m.addactual_cost = nil +} + +// ActualCost returns the value of the "actual_cost" field in the mutation. +func (m *UsageLogMutation) ActualCost() (r float64, exists bool) { + v := m.actual_cost + if v == nil { + return + } + return *v, true +} + +// OldActualCost returns the old "actual_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldActualCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldActualCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldActualCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldActualCost: %w", err) + } + return oldValue.ActualCost, nil +} + +// AddActualCost adds f to the "actual_cost" field. 
+func (m *UsageLogMutation) AddActualCost(f float64) { + if m.addactual_cost != nil { + *m.addactual_cost += f + } else { + m.addactual_cost = &f + } +} + +// AddedActualCost returns the value that was added to the "actual_cost" field in this mutation. +func (m *UsageLogMutation) AddedActualCost() (r float64, exists bool) { + v := m.addactual_cost + if v == nil { + return + } + return *v, true +} + +// ResetActualCost resets all changes to the "actual_cost" field. +func (m *UsageLogMutation) ResetActualCost() { + m.actual_cost = nil + m.addactual_cost = nil +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (m *UsageLogMutation) SetRateMultiplier(f float64) { + m.rate_multiplier = &f + m.addrate_multiplier = nil +} + +// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. +func (m *UsageLogMutation) RateMultiplier() (r float64, exists bool) { + v := m.rate_multiplier + if v == nil { + return + } + return *v, true +} + +// OldRateMultiplier returns the old "rate_multiplier" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) + } + return oldValue.RateMultiplier, nil +} + +// AddRateMultiplier adds f to the "rate_multiplier" field. 
+func (m *UsageLogMutation) AddRateMultiplier(f float64) { + if m.addrate_multiplier != nil { + *m.addrate_multiplier += f + } else { + m.addrate_multiplier = &f + } +} + +// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. +func (m *UsageLogMutation) AddedRateMultiplier() (r float64, exists bool) { + v := m.addrate_multiplier + if v == nil { + return + } + return *v, true +} + +// ResetRateMultiplier resets all changes to the "rate_multiplier" field. +func (m *UsageLogMutation) ResetRateMultiplier() { + m.rate_multiplier = nil + m.addrate_multiplier = nil +} + +// SetBillingType sets the "billing_type" field. +func (m *UsageLogMutation) SetBillingType(i int8) { + m.billing_type = &i + m.addbilling_type = nil +} + +// BillingType returns the value of the "billing_type" field in the mutation. +func (m *UsageLogMutation) BillingType() (r int8, exists bool) { + v := m.billing_type + if v == nil { + return + } + return *v, true +} + +// OldBillingType returns the old "billing_type" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingType(ctx context.Context) (v int8, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingType: %w", err) + } + return oldValue.BillingType, nil +} + +// AddBillingType adds i to the "billing_type" field. 
+func (m *UsageLogMutation) AddBillingType(i int8) { + if m.addbilling_type != nil { + *m.addbilling_type += i + } else { + m.addbilling_type = &i + } +} + +// AddedBillingType returns the value that was added to the "billing_type" field in this mutation. +func (m *UsageLogMutation) AddedBillingType() (r int8, exists bool) { + v := m.addbilling_type + if v == nil { + return + } + return *v, true +} + +// ResetBillingType resets all changes to the "billing_type" field. +func (m *UsageLogMutation) ResetBillingType() { + m.billing_type = nil + m.addbilling_type = nil +} + +// SetStream sets the "stream" field. +func (m *UsageLogMutation) SetStream(b bool) { + m.stream = &b +} + +// Stream returns the value of the "stream" field in the mutation. +func (m *UsageLogMutation) Stream() (r bool, exists bool) { + v := m.stream + if v == nil { + return + } + return *v, true +} + +// OldStream returns the old "stream" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldStream(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStream is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStream requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStream: %w", err) + } + return oldValue.Stream, nil +} + +// ResetStream resets all changes to the "stream" field. +func (m *UsageLogMutation) ResetStream() { + m.stream = nil +} + +// SetDurationMs sets the "duration_ms" field. +func (m *UsageLogMutation) SetDurationMs(i int) { + m.duration_ms = &i + m.addduration_ms = nil +} + +// DurationMs returns the value of the "duration_ms" field in the mutation. 
+func (m *UsageLogMutation) DurationMs() (r int, exists bool) { + v := m.duration_ms + if v == nil { + return + } + return *v, true +} + +// OldDurationMs returns the old "duration_ms" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldDurationMs(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDurationMs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDurationMs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDurationMs: %w", err) + } + return oldValue.DurationMs, nil +} + +// AddDurationMs adds i to the "duration_ms" field. +func (m *UsageLogMutation) AddDurationMs(i int) { + if m.addduration_ms != nil { + *m.addduration_ms += i + } else { + m.addduration_ms = &i + } +} + +// AddedDurationMs returns the value that was added to the "duration_ms" field in this mutation. +func (m *UsageLogMutation) AddedDurationMs() (r int, exists bool) { + v := m.addduration_ms + if v == nil { + return + } + return *v, true +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (m *UsageLogMutation) ClearDurationMs() { + m.duration_ms = nil + m.addduration_ms = nil + m.clearedFields[usagelog.FieldDurationMs] = struct{}{} +} + +// DurationMsCleared returns if the "duration_ms" field was cleared in this mutation. +func (m *UsageLogMutation) DurationMsCleared() bool { + _, ok := m.clearedFields[usagelog.FieldDurationMs] + return ok +} + +// ResetDurationMs resets all changes to the "duration_ms" field. 
+func (m *UsageLogMutation) ResetDurationMs() { + m.duration_ms = nil + m.addduration_ms = nil + delete(m.clearedFields, usagelog.FieldDurationMs) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (m *UsageLogMutation) SetFirstTokenMs(i int) { + m.first_token_ms = &i + m.addfirst_token_ms = nil +} + +// FirstTokenMs returns the value of the "first_token_ms" field in the mutation. +func (m *UsageLogMutation) FirstTokenMs() (r int, exists bool) { + v := m.first_token_ms + if v == nil { + return + } + return *v, true +} + +// OldFirstTokenMs returns the old "first_token_ms" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldFirstTokenMs(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFirstTokenMs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFirstTokenMs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFirstTokenMs: %w", err) + } + return oldValue.FirstTokenMs, nil +} + +// AddFirstTokenMs adds i to the "first_token_ms" field. +func (m *UsageLogMutation) AddFirstTokenMs(i int) { + if m.addfirst_token_ms != nil { + *m.addfirst_token_ms += i + } else { + m.addfirst_token_ms = &i + } +} + +// AddedFirstTokenMs returns the value that was added to the "first_token_ms" field in this mutation. +func (m *UsageLogMutation) AddedFirstTokenMs() (r int, exists bool) { + v := m.addfirst_token_ms + if v == nil { + return + } + return *v, true +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. 
+func (m *UsageLogMutation) ClearFirstTokenMs() { + m.first_token_ms = nil + m.addfirst_token_ms = nil + m.clearedFields[usagelog.FieldFirstTokenMs] = struct{}{} +} + +// FirstTokenMsCleared returns if the "first_token_ms" field was cleared in this mutation. +func (m *UsageLogMutation) FirstTokenMsCleared() bool { + _, ok := m.clearedFields[usagelog.FieldFirstTokenMs] + return ok +} + +// ResetFirstTokenMs resets all changes to the "first_token_ms" field. +func (m *UsageLogMutation) ResetFirstTokenMs() { + m.first_token_ms = nil + m.addfirst_token_ms = nil + delete(m.clearedFields, usagelog.FieldFirstTokenMs) +} + +// SetCreatedAt sets the "created_at" field. +func (m *UsageLogMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UsageLogMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UsageLogMutation) ResetCreatedAt() { + m.created_at = nil +} + +// ClearUser clears the "user" edge to the User entity. 
+func (m *UsageLogMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[usagelog.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *UsageLogMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *UsageLogMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. +func (m *UsageLogMutation) ClearAPIKey() { + m.clearedapi_key = true + m.clearedFields[usagelog.FieldAPIKeyID] = struct{}{} +} + +// APIKeyCleared reports if the "api_key" edge to the ApiKey entity was cleared. +func (m *UsageLogMutation) APIKeyCleared() bool { + return m.clearedapi_key +} + +// APIKeyIDs returns the "api_key" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// APIKeyID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) APIKeyIDs() (ids []int64) { + if id := m.api_key; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAPIKey resets all changes to the "api_key" edge. +func (m *UsageLogMutation) ResetAPIKey() { + m.api_key = nil + m.clearedapi_key = false +} + +// ClearAccount clears the "account" edge to the Account entity. +func (m *UsageLogMutation) ClearAccount() { + m.clearedaccount = true + m.clearedFields[usagelog.FieldAccountID] = struct{}{} +} + +// AccountCleared reports if the "account" edge to the Account entity was cleared. 
+func (m *UsageLogMutation) AccountCleared() bool { + return m.clearedaccount +} + +// AccountIDs returns the "account" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AccountID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) AccountIDs() (ids []int64) { + if id := m.account; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAccount resets all changes to the "account" edge. +func (m *UsageLogMutation) ResetAccount() { + m.account = nil + m.clearedaccount = false +} + +// ClearGroup clears the "group" edge to the Group entity. +func (m *UsageLogMutation) ClearGroup() { + m.clearedgroup = true + m.clearedFields[usagelog.FieldGroupID] = struct{}{} +} + +// GroupCleared reports if the "group" edge to the Group entity was cleared. +func (m *UsageLogMutation) GroupCleared() bool { + return m.GroupIDCleared() || m.clearedgroup +} + +// GroupIDs returns the "group" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// GroupID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) GroupIDs() (ids []int64) { + if id := m.group; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetGroup resets all changes to the "group" edge. +func (m *UsageLogMutation) ResetGroup() { + m.group = nil + m.clearedgroup = false +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (m *UsageLogMutation) ClearSubscription() { + m.clearedsubscription = true + m.clearedFields[usagelog.FieldSubscriptionID] = struct{}{} +} + +// SubscriptionCleared reports if the "subscription" edge to the UserSubscription entity was cleared. +func (m *UsageLogMutation) SubscriptionCleared() bool { + return m.SubscriptionIDCleared() || m.clearedsubscription +} + +// SubscriptionIDs returns the "subscription" edge IDs in the mutation. 
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// SubscriptionID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) SubscriptionIDs() (ids []int64) { + if id := m.subscription; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetSubscription resets all changes to the "subscription" edge. +func (m *UsageLogMutation) ResetSubscription() { + m.subscription = nil + m.clearedsubscription = false +} + +// Where appends a list predicates to the UsageLogMutation builder. +func (m *UsageLogMutation) Where(ps ...predicate.UsageLog) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UsageLogMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UsageLogMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UsageLog, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UsageLogMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UsageLogMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (UsageLog). +func (m *UsageLogMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *UsageLogMutation) Fields() []string { + fields := make([]string, 0, 25) + if m.user != nil { + fields = append(fields, usagelog.FieldUserID) + } + if m.api_key != nil { + fields = append(fields, usagelog.FieldAPIKeyID) + } + if m.account != nil { + fields = append(fields, usagelog.FieldAccountID) + } + if m.request_id != nil { + fields = append(fields, usagelog.FieldRequestID) + } + if m.model != nil { + fields = append(fields, usagelog.FieldModel) + } + if m.group != nil { + fields = append(fields, usagelog.FieldGroupID) + } + if m.subscription != nil { + fields = append(fields, usagelog.FieldSubscriptionID) + } + if m.input_tokens != nil { + fields = append(fields, usagelog.FieldInputTokens) + } + if m.output_tokens != nil { + fields = append(fields, usagelog.FieldOutputTokens) + } + if m.cache_creation_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreationTokens) + } + if m.cache_read_tokens != nil { + fields = append(fields, usagelog.FieldCacheReadTokens) + } + if m.cache_creation_5m_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation5mTokens) + } + if m.cache_creation_1h_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation1hTokens) + } + if m.input_cost != nil { + fields = append(fields, usagelog.FieldInputCost) + } + if m.output_cost != nil { + fields = append(fields, usagelog.FieldOutputCost) + } + if m.cache_creation_cost != nil { + fields = append(fields, usagelog.FieldCacheCreationCost) + } + if m.cache_read_cost != nil { + fields = append(fields, usagelog.FieldCacheReadCost) + } + if m.total_cost != nil { + fields = append(fields, usagelog.FieldTotalCost) + } + if m.actual_cost != nil { + fields = append(fields, usagelog.FieldActualCost) + } + if m.rate_multiplier != nil { + fields = append(fields, usagelog.FieldRateMultiplier) + } + if m.billing_type != nil { + fields = append(fields, usagelog.FieldBillingType) + } + if m.stream != nil { + fields = append(fields, usagelog.FieldStream) + } + if 
m.duration_ms != nil { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.first_token_ms != nil { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + if m.created_at != nil { + fields = append(fields, usagelog.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { + switch name { + case usagelog.FieldUserID: + return m.UserID() + case usagelog.FieldAPIKeyID: + return m.APIKeyID() + case usagelog.FieldAccountID: + return m.AccountID() + case usagelog.FieldRequestID: + return m.RequestID() + case usagelog.FieldModel: + return m.Model() + case usagelog.FieldGroupID: + return m.GroupID() + case usagelog.FieldSubscriptionID: + return m.SubscriptionID() + case usagelog.FieldInputTokens: + return m.InputTokens() + case usagelog.FieldOutputTokens: + return m.OutputTokens() + case usagelog.FieldCacheCreationTokens: + return m.CacheCreationTokens() + case usagelog.FieldCacheReadTokens: + return m.CacheReadTokens() + case usagelog.FieldCacheCreation5mTokens: + return m.CacheCreation5mTokens() + case usagelog.FieldCacheCreation1hTokens: + return m.CacheCreation1hTokens() + case usagelog.FieldInputCost: + return m.InputCost() + case usagelog.FieldOutputCost: + return m.OutputCost() + case usagelog.FieldCacheCreationCost: + return m.CacheCreationCost() + case usagelog.FieldCacheReadCost: + return m.CacheReadCost() + case usagelog.FieldTotalCost: + return m.TotalCost() + case usagelog.FieldActualCost: + return m.ActualCost() + case usagelog.FieldRateMultiplier: + return m.RateMultiplier() + case usagelog.FieldBillingType: + return m.BillingType() + case usagelog.FieldStream: + return m.Stream() + case usagelog.FieldDurationMs: + return m.DurationMs() + case usagelog.FieldFirstTokenMs: + return m.FirstTokenMs() + case 
usagelog.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case usagelog.FieldUserID: + return m.OldUserID(ctx) + case usagelog.FieldAPIKeyID: + return m.OldAPIKeyID(ctx) + case usagelog.FieldAccountID: + return m.OldAccountID(ctx) + case usagelog.FieldRequestID: + return m.OldRequestID(ctx) + case usagelog.FieldModel: + return m.OldModel(ctx) + case usagelog.FieldGroupID: + return m.OldGroupID(ctx) + case usagelog.FieldSubscriptionID: + return m.OldSubscriptionID(ctx) + case usagelog.FieldInputTokens: + return m.OldInputTokens(ctx) + case usagelog.FieldOutputTokens: + return m.OldOutputTokens(ctx) + case usagelog.FieldCacheCreationTokens: + return m.OldCacheCreationTokens(ctx) + case usagelog.FieldCacheReadTokens: + return m.OldCacheReadTokens(ctx) + case usagelog.FieldCacheCreation5mTokens: + return m.OldCacheCreation5mTokens(ctx) + case usagelog.FieldCacheCreation1hTokens: + return m.OldCacheCreation1hTokens(ctx) + case usagelog.FieldInputCost: + return m.OldInputCost(ctx) + case usagelog.FieldOutputCost: + return m.OldOutputCost(ctx) + case usagelog.FieldCacheCreationCost: + return m.OldCacheCreationCost(ctx) + case usagelog.FieldCacheReadCost: + return m.OldCacheReadCost(ctx) + case usagelog.FieldTotalCost: + return m.OldTotalCost(ctx) + case usagelog.FieldActualCost: + return m.OldActualCost(ctx) + case usagelog.FieldRateMultiplier: + return m.OldRateMultiplier(ctx) + case usagelog.FieldBillingType: + return m.OldBillingType(ctx) + case usagelog.FieldStream: + return m.OldStream(ctx) + case usagelog.FieldDurationMs: + return m.OldDurationMs(ctx) + case usagelog.FieldFirstTokenMs: + return m.OldFirstTokenMs(ctx) + case usagelog.FieldCreatedAt: + return 
m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown UsageLog field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UsageLogMutation) SetField(name string, value ent.Value) error { + switch name { + case usagelog.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case usagelog.FieldAPIKeyID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAPIKeyID(v) + return nil + case usagelog.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccountID(v) + return nil + case usagelog.FieldRequestID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestID(v) + return nil + case usagelog.FieldModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModel(v) + return nil + case usagelog.FieldGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetGroupID(v) + return nil + case usagelog.FieldSubscriptionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionID(v) + return nil + case usagelog.FieldInputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetInputTokens(v) + return nil + case usagelog.FieldOutputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutputTokens(v) + return nil + case usagelog.FieldCacheCreationTokens: + v, ok := value.(int) + if !ok { + return 
fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreationTokens(v) + return nil + case usagelog.FieldCacheReadTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheReadTokens(v) + return nil + case usagelog.FieldCacheCreation5mTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreation5mTokens(v) + return nil + case usagelog.FieldCacheCreation1hTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreation1hTokens(v) + return nil + case usagelog.FieldInputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetInputCost(v) + return nil + case usagelog.FieldOutputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutputCost(v) + return nil + case usagelog.FieldCacheCreationCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreationCost(v) + return nil + case usagelog.FieldCacheReadCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheReadCost(v) + return nil + case usagelog.FieldTotalCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalCost(v) + return nil + case usagelog.FieldActualCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetActualCost(v) + return nil + case usagelog.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateMultiplier(v) + return nil + case usagelog.FieldBillingType: + v, ok := value.(int8) 
+ if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingType(v) + return nil + case usagelog.FieldStream: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStream(v) + return nil + case usagelog.FieldDurationMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDurationMs(v) + return nil + case usagelog.FieldFirstTokenMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFirstTokenMs(v) + return nil + case usagelog.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown UsageLog field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *UsageLogMutation) AddedFields() []string { + var fields []string + if m.addinput_tokens != nil { + fields = append(fields, usagelog.FieldInputTokens) + } + if m.addoutput_tokens != nil { + fields = append(fields, usagelog.FieldOutputTokens) + } + if m.addcache_creation_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreationTokens) + } + if m.addcache_read_tokens != nil { + fields = append(fields, usagelog.FieldCacheReadTokens) + } + if m.addcache_creation_5m_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation5mTokens) + } + if m.addcache_creation_1h_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation1hTokens) + } + if m.addinput_cost != nil { + fields = append(fields, usagelog.FieldInputCost) + } + if m.addoutput_cost != nil { + fields = append(fields, usagelog.FieldOutputCost) + } + if m.addcache_creation_cost != nil { + fields = append(fields, usagelog.FieldCacheCreationCost) + } + if m.addcache_read_cost != nil { + fields = append(fields, 
usagelog.FieldCacheReadCost) + } + if m.addtotal_cost != nil { + fields = append(fields, usagelog.FieldTotalCost) + } + if m.addactual_cost != nil { + fields = append(fields, usagelog.FieldActualCost) + } + if m.addrate_multiplier != nil { + fields = append(fields, usagelog.FieldRateMultiplier) + } + if m.addbilling_type != nil { + fields = append(fields, usagelog.FieldBillingType) + } + if m.addduration_ms != nil { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.addfirst_token_ms != nil { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case usagelog.FieldInputTokens: + return m.AddedInputTokens() + case usagelog.FieldOutputTokens: + return m.AddedOutputTokens() + case usagelog.FieldCacheCreationTokens: + return m.AddedCacheCreationTokens() + case usagelog.FieldCacheReadTokens: + return m.AddedCacheReadTokens() + case usagelog.FieldCacheCreation5mTokens: + return m.AddedCacheCreation5mTokens() + case usagelog.FieldCacheCreation1hTokens: + return m.AddedCacheCreation1hTokens() + case usagelog.FieldInputCost: + return m.AddedInputCost() + case usagelog.FieldOutputCost: + return m.AddedOutputCost() + case usagelog.FieldCacheCreationCost: + return m.AddedCacheCreationCost() + case usagelog.FieldCacheReadCost: + return m.AddedCacheReadCost() + case usagelog.FieldTotalCost: + return m.AddedTotalCost() + case usagelog.FieldActualCost: + return m.AddedActualCost() + case usagelog.FieldRateMultiplier: + return m.AddedRateMultiplier() + case usagelog.FieldBillingType: + return m.AddedBillingType() + case usagelog.FieldDurationMs: + return m.AddedDurationMs() + case usagelog.FieldFirstTokenMs: + return m.AddedFirstTokenMs() + } 
+ return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UsageLogMutation) AddField(name string, value ent.Value) error { + switch name { + case usagelog.FieldInputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddInputTokens(v) + return nil + case usagelog.FieldOutputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOutputTokens(v) + return nil + case usagelog.FieldCacheCreationTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreationTokens(v) + return nil + case usagelog.FieldCacheReadTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheReadTokens(v) + return nil + case usagelog.FieldCacheCreation5mTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreation5mTokens(v) + return nil + case usagelog.FieldCacheCreation1hTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreation1hTokens(v) + return nil + case usagelog.FieldInputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddInputCost(v) + return nil + case usagelog.FieldOutputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOutputCost(v) + return nil + case usagelog.FieldCacheCreationCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreationCost(v) + return nil + case usagelog.FieldCacheReadCost: + v, ok := 
value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheReadCost(v) + return nil + case usagelog.FieldTotalCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalCost(v) + return nil + case usagelog.FieldActualCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddActualCost(v) + return nil + case usagelog.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateMultiplier(v) + return nil + case usagelog.FieldBillingType: + v, ok := value.(int8) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddBillingType(v) + return nil + case usagelog.FieldDurationMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDurationMs(v) + return nil + case usagelog.FieldFirstTokenMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFirstTokenMs(v) + return nil + } + return fmt.Errorf("unknown UsageLog numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UsageLogMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(usagelog.FieldGroupID) { + fields = append(fields, usagelog.FieldGroupID) + } + if m.FieldCleared(usagelog.FieldSubscriptionID) { + fields = append(fields, usagelog.FieldSubscriptionID) + } + if m.FieldCleared(usagelog.FieldDurationMs) { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.FieldCleared(usagelog.FieldFirstTokenMs) { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. 
+func (m *UsageLogMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UsageLogMutation) ClearField(name string) error { + switch name { + case usagelog.FieldGroupID: + m.ClearGroupID() + return nil + case usagelog.FieldSubscriptionID: + m.ClearSubscriptionID() + return nil + case usagelog.FieldDurationMs: + m.ClearDurationMs() + return nil + case usagelog.FieldFirstTokenMs: + m.ClearFirstTokenMs() + return nil + } + return fmt.Errorf("unknown UsageLog nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *UsageLogMutation) ResetField(name string) error { + switch name { + case usagelog.FieldUserID: + m.ResetUserID() + return nil + case usagelog.FieldAPIKeyID: + m.ResetAPIKeyID() + return nil + case usagelog.FieldAccountID: + m.ResetAccountID() + return nil + case usagelog.FieldRequestID: + m.ResetRequestID() + return nil + case usagelog.FieldModel: + m.ResetModel() + return nil + case usagelog.FieldGroupID: + m.ResetGroupID() + return nil + case usagelog.FieldSubscriptionID: + m.ResetSubscriptionID() + return nil + case usagelog.FieldInputTokens: + m.ResetInputTokens() + return nil + case usagelog.FieldOutputTokens: + m.ResetOutputTokens() + return nil + case usagelog.FieldCacheCreationTokens: + m.ResetCacheCreationTokens() + return nil + case usagelog.FieldCacheReadTokens: + m.ResetCacheReadTokens() + return nil + case usagelog.FieldCacheCreation5mTokens: + m.ResetCacheCreation5mTokens() + return nil + case usagelog.FieldCacheCreation1hTokens: + m.ResetCacheCreation1hTokens() + return nil + case usagelog.FieldInputCost: + m.ResetInputCost() + return nil + case usagelog.FieldOutputCost: + m.ResetOutputCost() + return nil + case 
usagelog.FieldCacheCreationCost: + m.ResetCacheCreationCost() + return nil + case usagelog.FieldCacheReadCost: + m.ResetCacheReadCost() + return nil + case usagelog.FieldTotalCost: + m.ResetTotalCost() + return nil + case usagelog.FieldActualCost: + m.ResetActualCost() + return nil + case usagelog.FieldRateMultiplier: + m.ResetRateMultiplier() + return nil + case usagelog.FieldBillingType: + m.ResetBillingType() + return nil + case usagelog.FieldStream: + m.ResetStream() + return nil + case usagelog.FieldDurationMs: + m.ResetDurationMs() + return nil + case usagelog.FieldFirstTokenMs: + m.ResetFirstTokenMs() + return nil + case usagelog.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown UsageLog field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UsageLogMutation) AddedEdges() []string { + edges := make([]string, 0, 5) + if m.user != nil { + edges = append(edges, usagelog.EdgeUser) + } + if m.api_key != nil { + edges = append(edges, usagelog.EdgeAPIKey) + } + if m.account != nil { + edges = append(edges, usagelog.EdgeAccount) + } + if m.group != nil { + edges = append(edges, usagelog.EdgeGroup) + } + if m.subscription != nil { + edges = append(edges, usagelog.EdgeSubscription) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *UsageLogMutation) AddedIDs(name string) []ent.Value { + switch name { + case usagelog.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeAPIKey: + if id := m.api_key; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeAccount: + if id := m.account; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeGroup: + if id := m.group; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeSubscription: + if id := m.subscription; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UsageLogMutation) RemovedEdges() []string { + edges := make([]string, 0, 5) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UsageLogMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UsageLogMutation) ClearedEdges() []string { + edges := make([]string, 0, 5) + if m.cleareduser { + edges = append(edges, usagelog.EdgeUser) + } + if m.clearedapi_key { + edges = append(edges, usagelog.EdgeAPIKey) + } + if m.clearedaccount { + edges = append(edges, usagelog.EdgeAccount) + } + if m.clearedgroup { + edges = append(edges, usagelog.EdgeGroup) + } + if m.clearedsubscription { + edges = append(edges, usagelog.EdgeSubscription) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. 
+func (m *UsageLogMutation) EdgeCleared(name string) bool { + switch name { + case usagelog.EdgeUser: + return m.cleareduser + case usagelog.EdgeAPIKey: + return m.clearedapi_key + case usagelog.EdgeAccount: + return m.clearedaccount + case usagelog.EdgeGroup: + return m.clearedgroup + case usagelog.EdgeSubscription: + return m.clearedsubscription + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UsageLogMutation) ClearEdge(name string) error { + switch name { + case usagelog.EdgeUser: + m.ClearUser() + return nil + case usagelog.EdgeAPIKey: + m.ClearAPIKey() + return nil + case usagelog.EdgeAccount: + m.ClearAccount() + return nil + case usagelog.EdgeGroup: + m.ClearGroup() + return nil + case usagelog.EdgeSubscription: + m.ClearSubscription() + return nil + } + return fmt.Errorf("unknown UsageLog unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UsageLogMutation) ResetEdge(name string) error { + switch name { + case usagelog.EdgeUser: + m.ResetUser() + return nil + case usagelog.EdgeAPIKey: + m.ResetAPIKey() + return nil + case usagelog.EdgeAccount: + m.ResetAccount() + return nil + case usagelog.EdgeGroup: + m.ResetGroup() + return nil + case usagelog.EdgeSubscription: + m.ResetSubscription() + return nil + } + return fmt.Errorf("unknown UsageLog edge %s", name) +} + // UserMutation represents an operation that mutates the User nodes in the graph. 
type UserMutation struct { config @@ -7259,6 +10176,9 @@ type UserMutation struct { allowed_groups map[int64]struct{} removedallowed_groups map[int64]struct{} clearedallowed_groups bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -8117,6 +11037,60 @@ func (m *UserMutation) ResetAllowedGroups() { m.removedallowed_groups = nil } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *UserMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *UserMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *UserMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *UserMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *UserMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *UserMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. 
+func (m *UserMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -8473,7 +11447,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -8489,6 +11463,9 @@ func (m *UserMutation) AddedEdges() []string { if m.allowed_groups != nil { edges = append(edges, user.EdgeAllowedGroups) } + if m.usage_logs != nil { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8526,13 +11503,19 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -8548,6 +11531,9 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedallowed_groups != nil { edges = append(edges, user.EdgeAllowedGroups) } + if m.removedusage_logs != nil { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8585,13 +11571,19 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -8607,6 +11599,9 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedallowed_groups { edges = append(edges, user.EdgeAllowedGroups) } + if m.clearedusage_logs { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8624,6 +11619,8 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedassigned_subscriptions case user.EdgeAllowedGroups: return m.clearedallowed_groups + case user.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -8655,6 +11652,9 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgeAllowedGroups: m.ResetAllowedGroups() return nil + case user.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown User edge %s", name) } @@ -9084,6 +12084,7 @@ type UserSubscriptionMutation struct { id *int64 created_at *time.Time updated_at *time.Time + deleted_at *time.Time starts_at *time.Time expires_at *time.Time status *string @@ -9105,6 +12106,9 @@ type UserSubscriptionMutation struct { 
clearedgroup bool assigned_by_user *int64 clearedassigned_by_user bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool done bool oldValue func(context.Context) (*UserSubscription, error) predicates []predicate.UserSubscription @@ -9280,6 +12284,55 @@ func (m *UserSubscriptionMutation) ResetUpdatedAt() { m.updated_at = nil } +// SetDeletedAt sets the "deleted_at" field. +func (m *UserSubscriptionMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserSubscriptionMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the UserSubscription entity. +// If the UserSubscription object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserSubscriptionMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserSubscriptionMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[usersubscription.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. 
+func (m *UserSubscriptionMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[usersubscription.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserSubscriptionMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, usersubscription.FieldDeletedAt) +} + // SetUserID sets the "user_id" field. func (m *UserSubscriptionMutation) SetUserID(i int64) { m.user = &i @@ -10003,6 +13056,60 @@ func (m *UserSubscriptionMutation) ResetAssignedByUser() { m.clearedassigned_by_user = false } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *UserSubscriptionMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *UserSubscriptionMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *UserSubscriptionMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *UserSubscriptionMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *UserSubscriptionMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. 
+func (m *UserSubscriptionMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *UserSubscriptionMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the UserSubscriptionMutation builder. func (m *UserSubscriptionMutation) Where(ps ...predicate.UserSubscription) { m.predicates = append(m.predicates, ps...) @@ -10037,13 +13144,16 @@ func (m *UserSubscriptionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserSubscriptionMutation) Fields() []string { - fields := make([]string, 0, 16) + fields := make([]string, 0, 17) if m.created_at != nil { fields = append(fields, usersubscription.FieldCreatedAt) } if m.updated_at != nil { fields = append(fields, usersubscription.FieldUpdatedAt) } + if m.deleted_at != nil { + fields = append(fields, usersubscription.FieldDeletedAt) + } if m.user != nil { fields = append(fields, usersubscription.FieldUserID) } @@ -10098,6 +13208,8 @@ func (m *UserSubscriptionMutation) Field(name string) (ent.Value, bool) { return m.CreatedAt() case usersubscription.FieldUpdatedAt: return m.UpdatedAt() + case usersubscription.FieldDeletedAt: + return m.DeletedAt() case usersubscription.FieldUserID: return m.UserID() case usersubscription.FieldGroupID: @@ -10139,6 +13251,8 @@ func (m *UserSubscriptionMutation) OldField(ctx context.Context, name string) (e return m.OldCreatedAt(ctx) case usersubscription.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case usersubscription.FieldDeletedAt: + return m.OldDeletedAt(ctx) case usersubscription.FieldUserID: return m.OldUserID(ctx) case usersubscription.FieldGroupID: @@ -10190,6 +13304,13 @@ func (m *UserSubscriptionMutation) SetField(name string, value ent.Value) error } m.SetUpdatedAt(v) return nil + 
case usersubscription.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil case usersubscription.FieldUserID: v, ok := value.(int64) if !ok { @@ -10357,6 +13478,9 @@ func (m *UserSubscriptionMutation) AddField(name string, value ent.Value) error // mutation. func (m *UserSubscriptionMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usersubscription.FieldDeletedAt) { + fields = append(fields, usersubscription.FieldDeletedAt) + } if m.FieldCleared(usersubscription.FieldDailyWindowStart) { fields = append(fields, usersubscription.FieldDailyWindowStart) } @@ -10386,6 +13510,9 @@ func (m *UserSubscriptionMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UserSubscriptionMutation) ClearField(name string) error { switch name { + case usersubscription.FieldDeletedAt: + m.ClearDeletedAt() + return nil case usersubscription.FieldDailyWindowStart: m.ClearDailyWindowStart() return nil @@ -10415,6 +13542,9 @@ func (m *UserSubscriptionMutation) ResetField(name string) error { case usersubscription.FieldUpdatedAt: m.ResetUpdatedAt() return nil + case usersubscription.FieldDeletedAt: + m.ResetDeletedAt() + return nil case usersubscription.FieldUserID: m.ResetUserID() return nil @@ -10463,7 +13593,7 @@ func (m *UserSubscriptionMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. 
func (m *UserSubscriptionMutation) AddedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) if m.user != nil { edges = append(edges, usersubscription.EdgeUser) } @@ -10473,6 +13603,9 @@ func (m *UserSubscriptionMutation) AddedEdges() []string { if m.assigned_by_user != nil { edges = append(edges, usersubscription.EdgeAssignedByUser) } + if m.usage_logs != nil { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } @@ -10492,25 +13625,42 @@ func (m *UserSubscriptionMutation) AddedIDs(name string) []ent.Value { if id := m.assigned_by_user; id != nil { return []ent.Value{*id} } + case usersubscription.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserSubscriptionMutation) RemovedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) + if m.removedusage_logs != nil { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *UserSubscriptionMutation) RemovedIDs(name string) []ent.Value { + switch name { + case usersubscription.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. 
func (m *UserSubscriptionMutation) ClearedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) if m.cleareduser { edges = append(edges, usersubscription.EdgeUser) } @@ -10520,6 +13670,9 @@ func (m *UserSubscriptionMutation) ClearedEdges() []string { if m.clearedassigned_by_user { edges = append(edges, usersubscription.EdgeAssignedByUser) } + if m.clearedusage_logs { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } @@ -10533,6 +13686,8 @@ func (m *UserSubscriptionMutation) EdgeCleared(name string) bool { return m.clearedgroup case usersubscription.EdgeAssignedByUser: return m.clearedassigned_by_user + case usersubscription.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -10567,6 +13722,9 @@ func (m *UserSubscriptionMutation) ResetEdge(name string) error { case usersubscription.EdgeAssignedByUser: m.ResetAssignedByUser() return nil + case usersubscription.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown UserSubscription edge %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 467dad7b..f6bdf466 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type RedeemCode func(*sql.Selector) // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) +// UsageLog is the predicate function for usagelog builders. +type UsageLog func(*sql.Selector) + // User is the predicate function for user builders. type User func(*sql.Selector) diff --git a/backend/ent/proxy.go b/backend/ent/proxy.go index eb271c7a..5228b73e 100644 --- a/backend/ent/proxy.go +++ b/backend/ent/proxy.go @@ -36,10 +36,31 @@ type Proxy struct { // Password holds the value of the "password" field. Password *string `json:"password,omitempty"` // Status holds the value of the "status" field. 
- Status string `json:"status,omitempty"` + Status string `json:"status,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ProxyQuery when eager-loading is set. + Edges ProxyEdges `json:"edges"` selectValues sql.SelectValues } +// ProxyEdges holds the relations/edges for other nodes in the graph. +type ProxyEdges struct { + // Accounts holds the value of the accounts edge. + Accounts []*Account `json:"accounts,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// AccountsOrErr returns the Accounts value or an error if the edge +// was not loaded in eager-loading. +func (e ProxyEdges) AccountsOrErr() ([]*Account, error) { + if e.loadedTypes[0] { + return e.Accounts, nil + } + return nil, &NotLoadedError{edge: "accounts"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*Proxy) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -148,6 +169,11 @@ func (_m *Proxy) Value(name string) (ent.Value, error) { return _m.selectValues.Get(name) } +// QueryAccounts queries the "accounts" edge of the Proxy entity. +func (_m *Proxy) QueryAccounts() *AccountQuery { + return NewProxyClient(_m.config).QueryAccounts(_m) +} + // Update returns a builder for updating this Proxy. // Note that you need to call Proxy.Unwrap() before calling this method if this Proxy // was returned from a transaction, and the transaction was committed or rolled back. 
diff --git a/backend/ent/proxy/proxy.go b/backend/ent/proxy/proxy.go index e5e1067c..db7abcda 100644 --- a/backend/ent/proxy/proxy.go +++ b/backend/ent/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -34,8 +35,17 @@ const ( FieldPassword = "password" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // EdgeAccounts holds the string denoting the accounts edge name in mutations. + EdgeAccounts = "accounts" // Table holds the table name of the proxy in the database. Table = "proxies" + // AccountsTable is the table that holds the accounts relation/edge. + AccountsTable = "accounts" + // AccountsInverseTable is the table name for the Account entity. + // It exists in this package in order to avoid circular dependency with the "account" package. + AccountsInverseTable = "accounts" + // AccountsColumn is the table column denoting the accounts relation/edge. + AccountsColumn = "proxy_id" ) // Columns holds all SQL columns for proxy fields. @@ -150,3 +160,24 @@ func ByPassword(opts ...sql.OrderTermOption) OrderOption { func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } + +// ByAccountsCount orders the results by accounts count. +func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAccountsStep(), opts...) + } +} + +// ByAccounts orders the results by accounts terms. +func ByAccounts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAccountsStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newAccountsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AccountsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn), + ) +} diff --git a/backend/ent/proxy/where.go b/backend/ent/proxy/where.go index ad92cee6..0a31ad7e 100644 --- a/backend/ent/proxy/where.go +++ b/backend/ent/proxy/where.go @@ -6,6 +6,7 @@ import ( "time" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/Wei-Shaw/sub2api/ent/predicate" ) @@ -684,6 +685,29 @@ func StatusContainsFold(v string) predicate.Proxy { return predicate.Proxy(sql.FieldContainsFold(FieldStatus, v)) } +// HasAccounts applies the HasEdge predicate on the "accounts" edge. +func HasAccounts() predicate.Proxy { + return predicate.Proxy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAccountsWith applies the HasEdge predicate on the "accounts" edge with a given conditions (other predicates). +func HasAccountsWith(preds ...predicate.Account) predicate.Proxy { + return predicate.Proxy(func(s *sql.Selector) { + step := newAccountsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. 
func And(predicates ...predicate.Proxy) predicate.Proxy { return predicate.Proxy(sql.AndPredicates(predicates...)) diff --git a/backend/ent/proxy_create.go b/backend/ent/proxy_create.go index 386abaec..9687aaa2 100644 --- a/backend/ent/proxy_create.go +++ b/backend/ent/proxy_create.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -130,6 +131,21 @@ func (_c *ProxyCreate) SetNillableStatus(v *string) *ProxyCreate { return _c } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_c *ProxyCreate) AddAccountIDs(ids ...int64) *ProxyCreate { + _c.mutation.AddAccountIDs(ids...) + return _c +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_c *ProxyCreate) AddAccounts(v ...*Account) *ProxyCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. 
func (_c *ProxyCreate) Mutation() *ProxyMutation { return _c.mutation @@ -308,6 +324,22 @@ func (_c *ProxyCreate) createSpec() (*Proxy, *sqlgraph.CreateSpec) { _spec.SetField(proxy.FieldStatus, field.TypeString, value) _node.Status = value } + if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/proxy_query.go b/backend/ent/proxy_query.go index b0599553..1358eed2 100644 --- a/backend/ent/proxy_query.go +++ b/backend/ent/proxy_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -11,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -18,10 +20,11 @@ import ( // ProxyQuery is the builder for querying Proxy entities. type ProxyQuery struct { config - ctx *QueryContext - order []proxy.OrderOption - inters []Interceptor - predicates []predicate.Proxy + ctx *QueryContext + order []proxy.OrderOption + inters []Interceptor + predicates []predicate.Proxy + withAccounts *AccountQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -58,6 +61,28 @@ func (_q *ProxyQuery) Order(o ...proxy.OrderOption) *ProxyQuery { return _q } +// QueryAccounts chains the current query on the "accounts" edge. 
+func (_q *ProxyQuery) QueryAccounts() *AccountQuery { + query := (&AccountClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(proxy.Table, proxy.FieldID, selector), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Proxy entity from the query. // Returns a *NotFoundError when no Proxy was found. func (_q *ProxyQuery) First(ctx context.Context) (*Proxy, error) { @@ -245,17 +270,29 @@ func (_q *ProxyQuery) Clone() *ProxyQuery { return nil } return &ProxyQuery{ - config: _q.config, - ctx: _q.ctx.Clone(), - order: append([]proxy.OrderOption{}, _q.order...), - inters: append([]Interceptor{}, _q.inters...), - predicates: append([]predicate.Proxy{}, _q.predicates...), + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]proxy.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.Proxy{}, _q.predicates...), + withAccounts: _q.withAccounts.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, } } +// WithAccounts tells the query-builder to eager-load the nodes that are connected to +// the "accounts" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ProxyQuery) WithAccounts(opts ...func(*AccountQuery)) *ProxyQuery { + query := (&AccountClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAccounts = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. 
// It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -332,8 +369,11 @@ func (_q *ProxyQuery) prepareQuery(ctx context.Context) error { func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, error) { var ( - nodes = []*Proxy{} - _spec = _q.querySpec() + nodes = []*Proxy{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withAccounts != nil, + } ) _spec.ScanValues = func(columns []string) ([]any, error) { return (*Proxy).scanValues(nil, columns) @@ -341,6 +381,7 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, _spec.Assign = func(columns []string, values []any) error { node := &Proxy{config: _q.config} nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } for i := range hooks { @@ -352,9 +393,50 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, if len(nodes) == 0 { return nodes, nil } + if query := _q.withAccounts; query != nil { + if err := _q.loadAccounts(ctx, query, nodes, + func(n *Proxy) { n.Edges.Accounts = []*Account{} }, + func(n *Proxy, e *Account) { n.Edges.Accounts = append(n.Edges.Accounts, e) }); err != nil { + return nil, err + } + } return nodes, nil } +func (_q *ProxyQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Proxy, init func(*Proxy), assign func(*Proxy, *Account)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Proxy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(account.FieldProxyID) + } + query.Where(predicate.Account(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(proxy.AccountsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ProxyID + if fk == nil { + return 
fmt.Errorf(`foreign-key "proxy_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "proxy_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (_q *ProxyQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() _spec.Node.Columns = _q.ctx.Fields diff --git a/backend/ent/proxy_update.go b/backend/ent/proxy_update.go index 3f5e1a7f..d487857f 100644 --- a/backend/ent/proxy_update.go +++ b/backend/ent/proxy_update.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -171,11 +172,47 @@ func (_u *ProxyUpdate) SetNillableStatus(v *string) *ProxyUpdate { return _u } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_u *ProxyUpdate) AddAccountIDs(ids ...int64) *ProxyUpdate { + _u.mutation.AddAccountIDs(ids...) + return _u +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_u *ProxyUpdate) AddAccounts(v ...*Account) *ProxyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. func (_u *ProxyUpdate) Mutation() *ProxyMutation { return _u.mutation } +// ClearAccounts clears all "accounts" edges to the Account entity. +func (_u *ProxyUpdate) ClearAccounts() *ProxyUpdate { + _u.mutation.ClearAccounts() + return _u +} + +// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs. +func (_u *ProxyUpdate) RemoveAccountIDs(ids ...int64) *ProxyUpdate { + _u.mutation.RemoveAccountIDs(ids...) + return _u +} + +// RemoveAccounts removes "accounts" edges to Account entities. 
+func (_u *ProxyUpdate) RemoveAccounts(v ...*Account) *ProxyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAccountIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *ProxyUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -304,6 +341,51 @@ func (_u *ProxyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(proxy.FieldStatus, field.TypeString, value) } + if _u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := 
err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{proxy.Label} @@ -467,11 +549,47 @@ func (_u *ProxyUpdateOne) SetNillableStatus(v *string) *ProxyUpdateOne { return _u } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_u *ProxyUpdateOne) AddAccountIDs(ids ...int64) *ProxyUpdateOne { + _u.mutation.AddAccountIDs(ids...) + return _u +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_u *ProxyUpdateOne) AddAccounts(v ...*Account) *ProxyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. func (_u *ProxyUpdateOne) Mutation() *ProxyMutation { return _u.mutation } +// ClearAccounts clears all "accounts" edges to the Account entity. +func (_u *ProxyUpdateOne) ClearAccounts() *ProxyUpdateOne { + _u.mutation.ClearAccounts() + return _u +} + +// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs. +func (_u *ProxyUpdateOne) RemoveAccountIDs(ids ...int64) *ProxyUpdateOne { + _u.mutation.RemoveAccountIDs(ids...) + return _u +} + +// RemoveAccounts removes "accounts" edges to Account entities. +func (_u *ProxyUpdateOne) RemoveAccounts(v ...*Account) *ProxyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAccountIDs(ids...) +} + // Where appends a list predicates to the ProxyUpdate builder. func (_u *ProxyUpdateOne) Where(ps ...predicate.Proxy) *ProxyUpdateOne { _u.mutation.Where(ps...) 
@@ -630,6 +748,51 @@ func (_u *ProxyUpdateOne) sqlSave(ctx context.Context) (_node *Proxy, err error) if value, ok := _u.mutation.Status(); ok { _spec.SetField(proxy.FieldStatus, field.TypeString, value) } + if _u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Proxy{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ef5e6bec..da0accd7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -13,6 +13,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/setting" + 
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -259,6 +260,10 @@ func init() { group.DefaultSubscriptionType = groupDescSubscriptionType.Default.(string) // group.SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save. group.SubscriptionTypeValidator = groupDescSubscriptionType.Validators[0].(func(string) error) + // groupDescDefaultValidityDays is the schema descriptor for default_validity_days field. + groupDescDefaultValidityDays := groupFields[10].Descriptor() + // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. + group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) proxyMixin := schema.Proxy{}.Mixin() proxyMixinHooks1 := proxyMixin[1].Hooks() proxy.Hooks[0] = proxyMixinHooks1[0] @@ -420,6 +425,108 @@ func init() { setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + usagelogFields := schema.UsageLog{}.Fields() + _ = usagelogFields + // usagelogDescRequestID is the schema descriptor for request_id field. + usagelogDescRequestID := usagelogFields[3].Descriptor() + // usagelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save. 
+ usagelog.RequestIDValidator = func() func(string) error { + validators := usagelogDescRequestID.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(request_id string) error { + for _, fn := range fns { + if err := fn(request_id); err != nil { + return err + } + } + return nil + } + }() + // usagelogDescModel is the schema descriptor for model field. + usagelogDescModel := usagelogFields[4].Descriptor() + // usagelog.ModelValidator is a validator for the "model" field. It is called by the builders before save. + usagelog.ModelValidator = func() func(string) error { + validators := usagelogDescModel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(model string) error { + for _, fn := range fns { + if err := fn(model); err != nil { + return err + } + } + return nil + } + }() + // usagelogDescInputTokens is the schema descriptor for input_tokens field. + usagelogDescInputTokens := usagelogFields[7].Descriptor() + // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. + usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) + // usagelogDescOutputTokens is the schema descriptor for output_tokens field. + usagelogDescOutputTokens := usagelogFields[8].Descriptor() + // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. + usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) + // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. + usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. 
+ usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) + // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. + usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. + usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) + // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. + usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. + usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) + // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. + usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. + usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) + // usagelogDescInputCost is the schema descriptor for input_cost field. + usagelogDescInputCost := usagelogFields[13].Descriptor() + // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. + usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) + // usagelogDescOutputCost is the schema descriptor for output_cost field. + usagelogDescOutputCost := usagelogFields[14].Descriptor() + // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. + usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) + // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. 
+ usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. + usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) + // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. + usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. + usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) + // usagelogDescTotalCost is the schema descriptor for total_cost field. + usagelogDescTotalCost := usagelogFields[17].Descriptor() + // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. + usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) + // usagelogDescActualCost is the schema descriptor for actual_cost field. + usagelogDescActualCost := usagelogFields[18].Descriptor() + // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. + usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) + // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. + usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. + usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) + // usagelogDescBillingType is the schema descriptor for billing_type field. + usagelogDescBillingType := usagelogFields[20].Descriptor() + // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. + usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) + // usagelogDescStream is the schema descriptor for stream field. 
+ usagelogDescStream := usagelogFields[21].Descriptor() + // usagelog.DefaultStream holds the default value on creation for the stream field. + usagelog.DefaultStream = usagelogDescStream.Default.(bool) + // usagelogDescCreatedAt is the schema descriptor for created_at field. + usagelogDescCreatedAt := usagelogFields[24].Descriptor() + // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. + usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() userMixinHooks1 := userMixin[1].Hooks() user.Hooks[0] = userMixinHooks1[0] @@ -518,6 +625,10 @@ func init() { // userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field. userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time) usersubscriptionMixin := schema.UserSubscription{}.Mixin() + usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks() + usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0] + usersubscriptionMixinInters1 := usersubscriptionMixin[1].Interceptors() + usersubscription.Interceptors[0] = usersubscriptionMixinInters1[0] usersubscriptionMixinFields0 := usersubscriptionMixin[0].Fields() _ = usersubscriptionMixinFields0 usersubscriptionFields := schema.UserSubscription{}.Fields() diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index c1dd64af..2561dc17 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -168,6 +168,13 @@ func (Account) Edges() []ent.Edge { // 一个账户可以属于多个分组,一个分组可以包含多个账户 edge.To("groups", Group.Type). Through("account_groups", AccountGroup.Type), + // proxy: 账户使用的代理配置(可选的一对一关系) + // 使用已有的 proxy_id 外键字段 + edge.To("proxy", Proxy.Type). + Field("proxy_id"). 
+ Unique(), + // usage_logs: 该账户的使用日志 + edge.To("usage_logs", UsageLog.Type), } } diff --git a/backend/ent/schema/account_group.go b/backend/ent/schema/account_group.go index 66729752..aa270f08 100644 --- a/backend/ent/schema/account_group.go +++ b/backend/ent/schema/account_group.go @@ -4,6 +4,7 @@ import ( "time" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -33,7 +34,8 @@ func (AccountGroup) Fields() []ent.Field { Default(50), field.Time("created_at"). Immutable(). - Default(time.Now), + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), } } diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 0f0f830e..f9ece05e 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -60,12 +60,13 @@ func (ApiKey) Edges() []ent.Edge { Ref("api_keys"). Field("group_id"). Unique(), + edge.To("usage_logs", UsageLog.Type), } } func (ApiKey) Indexes() []ent.Index { return []ent.Index{ - index.Fields("key").Unique(), + // key 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("user_id"), index.Fields("group_id"), index.Fields("status"), diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 2c30c979..7f3ed167 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -69,6 +69,8 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Int("default_validity_days"). + Default(30), } } @@ -77,6 +79,7 @@ func (Group) Edges() []ent.Edge { edge.To("api_keys", ApiKey.Type), edge.To("redeem_codes", RedeemCode.Type), edge.To("subscriptions", UserSubscription.Type), + edge.To("usage_logs", UsageLog.Type), edge.From("accounts", Account.Type). Ref("groups"). 
Through("account_groups", AccountGroup.Type), @@ -88,7 +91,7 @@ func (Group) Edges() []ent.Edge { func (Group) Indexes() []ent.Index { return []ent.Index{ - index.Fields("name").Unique(), + // name 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("platform"), index.Fields("subscription_type"), diff --git a/backend/ent/schema/proxy.go b/backend/ent/schema/proxy.go index 45608c96..46d657d3 100644 --- a/backend/ent/schema/proxy.go +++ b/backend/ent/schema/proxy.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" ) @@ -54,6 +55,15 @@ func (Proxy) Fields() []ent.Field { } } +// Edges 定义代理实体的关联关系。 +func (Proxy) Edges() []ent.Edge { + return []ent.Edge{ + // accounts: 使用此代理的账户(反向边) + edge.From("accounts", Account.Type). + Ref("proxy"), + } +} + func (Proxy) Indexes() []ent.Index { return []ent.Index{ index.Fields("status"), diff --git a/backend/ent/schema/redeem_code.go b/backend/ent/schema/redeem_code.go index 0ecb48b7..b4664e06 100644 --- a/backend/ent/schema/redeem_code.go +++ b/backend/ent/schema/redeem_code.go @@ -15,6 +15,14 @@ import ( ) // RedeemCode holds the schema definition for the RedeemCode entity. 
+// +// 删除策略:硬删除 +// RedeemCode 使用硬删除而非软删除,原因如下: +// - 兑换码具有一次性使用特性,删除后无需保留历史记录 +// - 已使用的兑换码通过 status 和 used_at 字段追踪,无需依赖软删除 +// - 减少数据库存储压力和查询复杂度 +// +// 如需审计已删除的兑换码,建议在删除前将关键信息写入审计日志表。 type RedeemCode struct { ent.Schema } @@ -78,7 +86,7 @@ func (RedeemCode) Edges() []ent.Edge { func (RedeemCode) Indexes() []ent.Index { return []ent.Index{ - index.Fields("code").Unique(), + // code 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("used_by"), index.Fields("group_id"), diff --git a/backend/ent/schema/setting.go b/backend/ent/schema/setting.go index f31f2a41..3f896fab 100644 --- a/backend/ent/schema/setting.go +++ b/backend/ent/schema/setting.go @@ -8,10 +8,17 @@ import ( "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/field" - "entgo.io/ent/schema/index" ) // Setting holds the schema definition for the Setting entity. +// +// 删除策略:硬删除 +// Setting 使用硬删除而非软删除,原因如下: +// - 系统设置是简单的键值对,删除即意味着恢复默认值 +// - 设置变更通常通过应用日志追踪,无需在数据库层面保留历史 +// - 保持表结构简洁,避免无效数据积累 +// +// 如需设置变更审计,建议在更新/删除前将变更记录写入审计日志表。 type Setting struct { ent.Schema } @@ -43,7 +50,6 @@ func (Setting) Fields() []ent.Field { } func (Setting) Indexes() []ent.Index { - return []ent.Index{ - index.Fields("key").Unique(), - } + // key 字段已在 Fields() 中声明 Unique(),无需额外索引 + return nil } diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go new file mode 100644 index 00000000..6f78e8a9 --- /dev/null +++ b/backend/ent/schema/usage_log.go @@ -0,0 +1,152 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// UsageLog 定义使用日志实体的 schema。 +// +// 使用日志记录每次 API 调用的详细信息,包括 token 使用量、成本计算等。 +// 这是一个只追加的表,不支持更新和删除。 +type UsageLog struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (UsageLog) Annotations() []schema.Annotation { 
+ return []schema.Annotation{ + entsql.Annotation{Table: "usage_logs"}, + } +} + +// Fields 定义使用日志实体的所有字段。 +func (UsageLog) Fields() []ent.Field { + return []ent.Field{ + // 关联字段 + field.Int64("user_id"), + field.Int64("api_key_id"), + field.Int64("account_id"), + field.String("request_id"). + MaxLen(64). + NotEmpty(), + field.String("model"). + MaxLen(100). + NotEmpty(), + field.Int64("group_id"). + Optional(). + Nillable(), + field.Int64("subscription_id"). + Optional(). + Nillable(), + + // Token 计数字段 + field.Int("input_tokens"). + Default(0), + field.Int("output_tokens"). + Default(0), + field.Int("cache_creation_tokens"). + Default(0), + field.Int("cache_read_tokens"). + Default(0), + field.Int("cache_creation_5m_tokens"). + Default(0), + field.Int("cache_creation_1h_tokens"). + Default(0), + + // 成本字段 + field.Float("input_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("output_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("cache_creation_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("cache_read_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("total_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("actual_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("rate_multiplier"). + Default(1). + SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}), + + // 其他字段 + field.Int8("billing_type"). + Default(0), + field.Bool("stream"). + Default(false), + field.Int("duration_ms"). + Optional(). + Nillable(), + field.Int("first_token_ms"). + Optional(). + Nillable(), + + // 时间戳(只有 created_at,日志不可修改) + field.Time("created_at"). + Default(time.Now). + Immutable(). 
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +// Edges 定义使用日志实体的关联关系。 +func (UsageLog) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("usage_logs"). + Field("user_id"). + Required(). + Unique(), + edge.From("api_key", ApiKey.Type). + Ref("usage_logs"). + Field("api_key_id"). + Required(). + Unique(), + edge.From("account", Account.Type). + Ref("usage_logs"). + Field("account_id"). + Required(). + Unique(), + edge.From("group", Group.Type). + Ref("usage_logs"). + Field("group_id"). + Unique(), + edge.From("subscription", UserSubscription.Type). + Ref("usage_logs"). + Field("subscription_id"). + Unique(), + } +} + +// Indexes 定义数据库索引,优化查询性能。 +func (UsageLog) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("user_id"), + index.Fields("api_key_id"), + index.Fields("account_id"), + index.Fields("group_id"), + index.Fields("subscription_id"), + index.Fields("created_at"), + index.Fields("model"), + index.Fields("request_id"), + // 复合索引用于时间范围查询 + index.Fields("user_id", "created_at"), + index.Fields("api_key_id", "created_at"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index e76799ed..ba7f0ce7 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -73,12 +73,13 @@ func (User) Edges() []ent.Edge { edge.To("assigned_subscriptions", UserSubscription.Type), edge.To("allowed_groups", Group.Type). 
Through("user_allowed_groups", UserAllowedGroup.Type), + edge.To("usage_logs", UsageLog.Type), } } func (User) Indexes() []ent.Index { return []ent.Index{ - index.Fields("email").Unique(), + // email 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("deleted_at"), } diff --git a/backend/ent/schema/user_allowed_group.go b/backend/ent/schema/user_allowed_group.go index 8fce97c2..94156219 100644 --- a/backend/ent/schema/user_allowed_group.go +++ b/backend/ent/schema/user_allowed_group.go @@ -4,6 +4,7 @@ import ( "time" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -31,7 +32,8 @@ func (UserAllowedGroup) Fields() []ent.Field { field.Int64("group_id"), field.Time("created_at"). Immutable(). - Default(time.Now), + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index bcb0da71..88c4ea8f 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -29,6 +29,7 @@ func (UserSubscription) Annotations() []schema.Annotation { func (UserSubscription) Mixin() []ent.Mixin { return []ent.Mixin{ mixins.TimeMixin{}, + mixins.SoftDeleteMixin{}, } } @@ -97,6 +98,7 @@ func (UserSubscription) Edges() []ent.Edge { Ref("assigned_subscriptions"). Field("assigned_by"). Unique(), + edge.To("usage_logs", UsageLog.Type), } } @@ -108,5 +110,6 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("expires_at"), index.Fields("assigned_by"), index.Fields("user_id", "group_id").Unique(), + index.Fields("deleted_at"), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index fbb68edf..ecb0409d 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. 
Setting *SettingClient + // UsageLog is the client for interacting with the UsageLog builders. + UsageLog *UsageLogClient // User is the client for interacting with the User builders. User *UserClient // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. @@ -172,6 +174,7 @@ func (tx *Tx) init() { tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.Setting = NewSettingClient(tx.config) + tx.UsageLog = NewUsageLogClient(tx.config) tx.User = NewUserClient(tx.config) tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) tx.UserSubscription = NewUserSubscriptionClient(tx.config) @@ -238,7 +241,6 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error var _ dialect.Driver = (*txDriver)(nil) -// ExecContext 透传到底层事务,用于在 ent 事务中执行原生 SQL(与 ent 写入保持同一事务)。 // ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it. // See, database/sql#Tx.ExecContext for more information. func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { @@ -251,7 +253,6 @@ func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) return ex.ExecContext(ctx, query, args...) } -// QueryContext 透传到底层事务,用于在 ent 事务中执行原生查询并共享锁语义。 // QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it. // See, database/sql#Tx.QueryContext for more information. func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go new file mode 100644 index 00000000..e01780fe --- /dev/null +++ b/backend/ent/usagelog.go @@ -0,0 +1,491 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLog is the model entity for the UsageLog schema. +type UsageLog struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // APIKeyID holds the value of the "api_key_id" field. + APIKeyID int64 `json:"api_key_id,omitempty"` + // AccountID holds the value of the "account_id" field. + AccountID int64 `json:"account_id,omitempty"` + // RequestID holds the value of the "request_id" field. + RequestID string `json:"request_id,omitempty"` + // Model holds the value of the "model" field. + Model string `json:"model,omitempty"` + // GroupID holds the value of the "group_id" field. + GroupID *int64 `json:"group_id,omitempty"` + // SubscriptionID holds the value of the "subscription_id" field. + SubscriptionID *int64 `json:"subscription_id,omitempty"` + // InputTokens holds the value of the "input_tokens" field. + InputTokens int `json:"input_tokens,omitempty"` + // OutputTokens holds the value of the "output_tokens" field. + OutputTokens int `json:"output_tokens,omitempty"` + // CacheCreationTokens holds the value of the "cache_creation_tokens" field. + CacheCreationTokens int `json:"cache_creation_tokens,omitempty"` + // CacheReadTokens holds the value of the "cache_read_tokens" field. + CacheReadTokens int `json:"cache_read_tokens,omitempty"` + // CacheCreation5mTokens holds the value of the "cache_creation_5m_tokens" field. 
+ CacheCreation5mTokens int `json:"cache_creation_5m_tokens,omitempty"` + // CacheCreation1hTokens holds the value of the "cache_creation_1h_tokens" field. + CacheCreation1hTokens int `json:"cache_creation_1h_tokens,omitempty"` + // InputCost holds the value of the "input_cost" field. + InputCost float64 `json:"input_cost,omitempty"` + // OutputCost holds the value of the "output_cost" field. + OutputCost float64 `json:"output_cost,omitempty"` + // CacheCreationCost holds the value of the "cache_creation_cost" field. + CacheCreationCost float64 `json:"cache_creation_cost,omitempty"` + // CacheReadCost holds the value of the "cache_read_cost" field. + CacheReadCost float64 `json:"cache_read_cost,omitempty"` + // TotalCost holds the value of the "total_cost" field. + TotalCost float64 `json:"total_cost,omitempty"` + // ActualCost holds the value of the "actual_cost" field. + ActualCost float64 `json:"actual_cost,omitempty"` + // RateMultiplier holds the value of the "rate_multiplier" field. + RateMultiplier float64 `json:"rate_multiplier,omitempty"` + // BillingType holds the value of the "billing_type" field. + BillingType int8 `json:"billing_type,omitempty"` + // Stream holds the value of the "stream" field. + Stream bool `json:"stream,omitempty"` + // DurationMs holds the value of the "duration_ms" field. + DurationMs *int `json:"duration_ms,omitempty"` + // FirstTokenMs holds the value of the "first_token_ms" field. + FirstTokenMs *int `json:"first_token_ms,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the UsageLogQuery when eager-loading is set. + Edges UsageLogEdges `json:"edges"` + selectValues sql.SelectValues +} + +// UsageLogEdges holds the relations/edges for other nodes in the graph. +type UsageLogEdges struct { + // User holds the value of the user edge. 
+ User *User `json:"user,omitempty"` + // APIKey holds the value of the api_key edge. + APIKey *ApiKey `json:"api_key,omitempty"` + // Account holds the value of the account edge. + Account *Account `json:"account,omitempty"` + // Group holds the value of the group edge. + Group *Group `json:"group,omitempty"` + // Subscription holds the value of the subscription edge. + Subscription *UserSubscription `json:"subscription,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [5]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// APIKeyOrErr returns the APIKey value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) APIKeyOrErr() (*ApiKey, error) { + if e.APIKey != nil { + return e.APIKey, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: apikey.Label} + } + return nil, &NotLoadedError{edge: "api_key"} +} + +// AccountOrErr returns the Account value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) AccountOrErr() (*Account, error) { + if e.Account != nil { + return e.Account, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: account.Label} + } + return nil, &NotLoadedError{edge: "account"} +} + +// GroupOrErr returns the Group value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. 
+func (e UsageLogEdges) GroupOrErr() (*Group, error) { + if e.Group != nil { + return e.Group, nil + } else if e.loadedTypes[3] { + return nil, &NotFoundError{label: group.Label} + } + return nil, &NotLoadedError{edge: "group"} +} + +// SubscriptionOrErr returns the Subscription value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) SubscriptionOrErr() (*UserSubscription, error) { + if e.Subscription != nil { + return e.Subscription, nil + } else if e.loadedTypes[4] { + return nil, &NotFoundError{label: usersubscription.Label} + } + return nil, &NotLoadedError{edge: "subscription"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UsageLog) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case usagelog.FieldStream: + values[i] = new(sql.NullBool) + case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier: + values[i] = new(sql.NullFloat64) + case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs: + values[i] = new(sql.NullInt64) + case usagelog.FieldRequestID, usagelog.FieldModel: + values[i] = new(sql.NullString) + case usagelog.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UsageLog fields. 
+func (_m *UsageLog) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case usagelog.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case usagelog.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case usagelog.FieldAPIKeyID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field api_key_id", values[i]) + } else if value.Valid { + _m.APIKeyID = value.Int64 + } + case usagelog.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case usagelog.FieldRequestID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_id", values[i]) + } else if value.Valid { + _m.RequestID = value.String + } + case usagelog.FieldModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[i]) + } else if value.Valid { + _m.Model = value.String + } + case usagelog.FieldGroupID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field group_id", values[i]) + } else if value.Valid { + _m.GroupID = new(int64) + *_m.GroupID = value.Int64 + } + case usagelog.FieldSubscriptionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field subscription_id", values[i]) + } else if value.Valid { + _m.SubscriptionID = new(int64) + *_m.SubscriptionID = value.Int64 + } + case 
usagelog.FieldInputTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field input_tokens", values[i]) + } else if value.Valid { + _m.InputTokens = int(value.Int64) + } + case usagelog.FieldOutputTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field output_tokens", values[i]) + } else if value.Valid { + _m.OutputTokens = int(value.Int64) + } + case usagelog.FieldCacheCreationTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_tokens", values[i]) + } else if value.Valid { + _m.CacheCreationTokens = int(value.Int64) + } + case usagelog.FieldCacheReadTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_read_tokens", values[i]) + } else if value.Valid { + _m.CacheReadTokens = int(value.Int64) + } + case usagelog.FieldCacheCreation5mTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_5m_tokens", values[i]) + } else if value.Valid { + _m.CacheCreation5mTokens = int(value.Int64) + } + case usagelog.FieldCacheCreation1hTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_1h_tokens", values[i]) + } else if value.Valid { + _m.CacheCreation1hTokens = int(value.Int64) + } + case usagelog.FieldInputCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field input_cost", values[i]) + } else if value.Valid { + _m.InputCost = value.Float64 + } + case usagelog.FieldOutputCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field output_cost", values[i]) + } else if value.Valid { + _m.OutputCost = value.Float64 + } + case usagelog.FieldCacheCreationCost: + if value, ok := values[i].(*sql.NullFloat64); 
!ok { + return fmt.Errorf("unexpected type %T for field cache_creation_cost", values[i]) + } else if value.Valid { + _m.CacheCreationCost = value.Float64 + } + case usagelog.FieldCacheReadCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field cache_read_cost", values[i]) + } else if value.Valid { + _m.CacheReadCost = value.Float64 + } + case usagelog.FieldTotalCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field total_cost", values[i]) + } else if value.Valid { + _m.TotalCost = value.Float64 + } + case usagelog.FieldActualCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field actual_cost", values[i]) + } else if value.Valid { + _m.ActualCost = value.Float64 + } + case usagelog.FieldRateMultiplier: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i]) + } else if value.Valid { + _m.RateMultiplier = value.Float64 + } + case usagelog.FieldBillingType: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field billing_type", values[i]) + } else if value.Valid { + _m.BillingType = int8(value.Int64) + } + case usagelog.FieldStream: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field stream", values[i]) + } else if value.Valid { + _m.Stream = value.Bool + } + case usagelog.FieldDurationMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field duration_ms", values[i]) + } else if value.Valid { + _m.DurationMs = new(int) + *_m.DurationMs = int(value.Int64) + } + case usagelog.FieldFirstTokenMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field first_token_ms", values[i]) + } else if value.Valid { + _m.FirstTokenMs = new(int) + *_m.FirstTokenMs 
= int(value.Int64) + } + case usagelog.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UsageLog. +// This includes values selected through modifiers, order, etc. +func (_m *UsageLog) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the UsageLog entity. +func (_m *UsageLog) QueryUser() *UserQuery { + return NewUsageLogClient(_m.config).QueryUser(_m) +} + +// QueryAPIKey queries the "api_key" edge of the UsageLog entity. +func (_m *UsageLog) QueryAPIKey() *ApiKeyQuery { + return NewUsageLogClient(_m.config).QueryAPIKey(_m) +} + +// QueryAccount queries the "account" edge of the UsageLog entity. +func (_m *UsageLog) QueryAccount() *AccountQuery { + return NewUsageLogClient(_m.config).QueryAccount(_m) +} + +// QueryGroup queries the "group" edge of the UsageLog entity. +func (_m *UsageLog) QueryGroup() *GroupQuery { + return NewUsageLogClient(_m.config).QueryGroup(_m) +} + +// QuerySubscription queries the "subscription" edge of the UsageLog entity. +func (_m *UsageLog) QuerySubscription() *UserSubscriptionQuery { + return NewUsageLogClient(_m.config).QuerySubscription(_m) +} + +// Update returns a builder for updating this UsageLog. +// Note that you need to call UsageLog.Unwrap() before calling this method if this UsageLog +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *UsageLog) Update() *UsageLogUpdateOne { + return NewUsageLogClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UsageLog entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UsageLog) Unwrap() *UsageLog { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: UsageLog is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *UsageLog) String() string { + var builder strings.Builder + builder.WriteString("UsageLog(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("api_key_id=") + builder.WriteString(fmt.Sprintf("%v", _m.APIKeyID)) + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + builder.WriteString("request_id=") + builder.WriteString(_m.RequestID) + builder.WriteString(", ") + builder.WriteString("model=") + builder.WriteString(_m.Model) + builder.WriteString(", ") + if v := _m.GroupID; v != nil { + builder.WriteString("group_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SubscriptionID; v != nil { + builder.WriteString("subscription_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("input_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.InputTokens)) + builder.WriteString(", ") + builder.WriteString("output_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.OutputTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationTokens)) + builder.WriteString(", ") + builder.WriteString("cache_read_tokens=") + 
builder.WriteString(fmt.Sprintf("%v", _m.CacheReadTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_5m_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation5mTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_1h_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation1hTokens)) + builder.WriteString(", ") + builder.WriteString("input_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.InputCost)) + builder.WriteString(", ") + builder.WriteString("output_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.OutputCost)) + builder.WriteString(", ") + builder.WriteString("cache_creation_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationCost)) + builder.WriteString(", ") + builder.WriteString("cache_read_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheReadCost)) + builder.WriteString(", ") + builder.WriteString("total_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.TotalCost)) + builder.WriteString(", ") + builder.WriteString("actual_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.ActualCost)) + builder.WriteString(", ") + builder.WriteString("rate_multiplier=") + builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier)) + builder.WriteString(", ") + builder.WriteString("billing_type=") + builder.WriteString(fmt.Sprintf("%v", _m.BillingType)) + builder.WriteString(", ") + builder.WriteString("stream=") + builder.WriteString(fmt.Sprintf("%v", _m.Stream)) + builder.WriteString(", ") + if v := _m.DurationMs; v != nil { + builder.WriteString("duration_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.FirstTokenMs; v != nil { + builder.WriteString("first_token_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// UsageLogs is a parsable slice of UsageLog. 
+type UsageLogs []*UsageLog diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go new file mode 100644 index 00000000..bdc6f7e6 --- /dev/null +++ b/backend/ent/usagelog/usagelog.go @@ -0,0 +1,396 @@ +// Code generated by ent, DO NOT EDIT. + +package usagelog + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the usagelog type in the database. + Label = "usage_log" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldAPIKeyID holds the string denoting the api_key_id field in the database. + FieldAPIKeyID = "api_key_id" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldRequestID holds the string denoting the request_id field in the database. + FieldRequestID = "request_id" + // FieldModel holds the string denoting the model field in the database. + FieldModel = "model" + // FieldGroupID holds the string denoting the group_id field in the database. + FieldGroupID = "group_id" + // FieldSubscriptionID holds the string denoting the subscription_id field in the database. + FieldSubscriptionID = "subscription_id" + // FieldInputTokens holds the string denoting the input_tokens field in the database. + FieldInputTokens = "input_tokens" + // FieldOutputTokens holds the string denoting the output_tokens field in the database. + FieldOutputTokens = "output_tokens" + // FieldCacheCreationTokens holds the string denoting the cache_creation_tokens field in the database. + FieldCacheCreationTokens = "cache_creation_tokens" + // FieldCacheReadTokens holds the string denoting the cache_read_tokens field in the database. 
+ FieldCacheReadTokens = "cache_read_tokens" + // FieldCacheCreation5mTokens holds the string denoting the cache_creation_5m_tokens field in the database. + FieldCacheCreation5mTokens = "cache_creation_5m_tokens" + // FieldCacheCreation1hTokens holds the string denoting the cache_creation_1h_tokens field in the database. + FieldCacheCreation1hTokens = "cache_creation_1h_tokens" + // FieldInputCost holds the string denoting the input_cost field in the database. + FieldInputCost = "input_cost" + // FieldOutputCost holds the string denoting the output_cost field in the database. + FieldOutputCost = "output_cost" + // FieldCacheCreationCost holds the string denoting the cache_creation_cost field in the database. + FieldCacheCreationCost = "cache_creation_cost" + // FieldCacheReadCost holds the string denoting the cache_read_cost field in the database. + FieldCacheReadCost = "cache_read_cost" + // FieldTotalCost holds the string denoting the total_cost field in the database. + FieldTotalCost = "total_cost" + // FieldActualCost holds the string denoting the actual_cost field in the database. + FieldActualCost = "actual_cost" + // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database. + FieldRateMultiplier = "rate_multiplier" + // FieldBillingType holds the string denoting the billing_type field in the database. + FieldBillingType = "billing_type" + // FieldStream holds the string denoting the stream field in the database. + FieldStream = "stream" + // FieldDurationMs holds the string denoting the duration_ms field in the database. + FieldDurationMs = "duration_ms" + // FieldFirstTokenMs holds the string denoting the first_token_ms field in the database. + FieldFirstTokenMs = "first_token_ms" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // EdgeUser holds the string denoting the user edge name in mutations. 
+ EdgeUser = "user" + // EdgeAPIKey holds the string denoting the api_key edge name in mutations. + EdgeAPIKey = "api_key" + // EdgeAccount holds the string denoting the account edge name in mutations. + EdgeAccount = "account" + // EdgeGroup holds the string denoting the group edge name in mutations. + EdgeGroup = "group" + // EdgeSubscription holds the string denoting the subscription edge name in mutations. + EdgeSubscription = "subscription" + // Table holds the table name of the usagelog in the database. + Table = "usage_logs" + // UserTable is the table that holds the user relation/edge. + UserTable = "usage_logs" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // APIKeyTable is the table that holds the api_key relation/edge. + APIKeyTable = "usage_logs" + // APIKeyInverseTable is the table name for the ApiKey entity. + // It exists in this package in order to avoid circular dependency with the "apikey" package. + APIKeyInverseTable = "api_keys" + // APIKeyColumn is the table column denoting the api_key relation/edge. + APIKeyColumn = "api_key_id" + // AccountTable is the table that holds the account relation/edge. + AccountTable = "usage_logs" + // AccountInverseTable is the table name for the Account entity. + // It exists in this package in order to avoid circular dependency with the "account" package. + AccountInverseTable = "accounts" + // AccountColumn is the table column denoting the account relation/edge. + AccountColumn = "account_id" + // GroupTable is the table that holds the group relation/edge. + GroupTable = "usage_logs" + // GroupInverseTable is the table name for the Group entity. + // It exists in this package in order to avoid circular dependency with the "group" package. 
+ GroupInverseTable = "groups" + // GroupColumn is the table column denoting the group relation/edge. + GroupColumn = "group_id" + // SubscriptionTable is the table that holds the subscription relation/edge. + SubscriptionTable = "usage_logs" + // SubscriptionInverseTable is the table name for the UserSubscription entity. + // It exists in this package in order to avoid circular dependency with the "usersubscription" package. + SubscriptionInverseTable = "user_subscriptions" + // SubscriptionColumn is the table column denoting the subscription relation/edge. + SubscriptionColumn = "subscription_id" +) + +// Columns holds all SQL columns for usagelog fields. +var Columns = []string{ + FieldID, + FieldUserID, + FieldAPIKeyID, + FieldAccountID, + FieldRequestID, + FieldModel, + FieldGroupID, + FieldSubscriptionID, + FieldInputTokens, + FieldOutputTokens, + FieldCacheCreationTokens, + FieldCacheReadTokens, + FieldCacheCreation5mTokens, + FieldCacheCreation1hTokens, + FieldInputCost, + FieldOutputCost, + FieldCacheCreationCost, + FieldCacheReadCost, + FieldTotalCost, + FieldActualCost, + FieldRateMultiplier, + FieldBillingType, + FieldStream, + FieldDurationMs, + FieldFirstTokenMs, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save. + RequestIDValidator func(string) error + // ModelValidator is a validator for the "model" field. It is called by the builders before save. + ModelValidator func(string) error + // DefaultInputTokens holds the default value on creation for the "input_tokens" field. + DefaultInputTokens int + // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. 
+ DefaultOutputTokens int + // DefaultCacheCreationTokens holds the default value on creation for the "cache_creation_tokens" field. + DefaultCacheCreationTokens int + // DefaultCacheReadTokens holds the default value on creation for the "cache_read_tokens" field. + DefaultCacheReadTokens int + // DefaultCacheCreation5mTokens holds the default value on creation for the "cache_creation_5m_tokens" field. + DefaultCacheCreation5mTokens int + // DefaultCacheCreation1hTokens holds the default value on creation for the "cache_creation_1h_tokens" field. + DefaultCacheCreation1hTokens int + // DefaultInputCost holds the default value on creation for the "input_cost" field. + DefaultInputCost float64 + // DefaultOutputCost holds the default value on creation for the "output_cost" field. + DefaultOutputCost float64 + // DefaultCacheCreationCost holds the default value on creation for the "cache_creation_cost" field. + DefaultCacheCreationCost float64 + // DefaultCacheReadCost holds the default value on creation for the "cache_read_cost" field. + DefaultCacheReadCost float64 + // DefaultTotalCost holds the default value on creation for the "total_cost" field. + DefaultTotalCost float64 + // DefaultActualCost holds the default value on creation for the "actual_cost" field. + DefaultActualCost float64 + // DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field. + DefaultRateMultiplier float64 + // DefaultBillingType holds the default value on creation for the "billing_type" field. + DefaultBillingType int8 + // DefaultStream holds the default value on creation for the "stream" field. + DefaultStream bool + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the UsageLog queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByAPIKeyID orders the results by the api_key_id field. +func ByAPIKeyID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKeyID, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByRequestID orders the results by the request_id field. +func ByRequestID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestID, opts...).ToFunc() +} + +// ByModel orders the results by the model field. +func ByModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModel, opts...).ToFunc() +} + +// ByGroupID orders the results by the group_id field. +func ByGroupID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGroupID, opts...).ToFunc() +} + +// BySubscriptionID orders the results by the subscription_id field. +func BySubscriptionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSubscriptionID, opts...).ToFunc() +} + +// ByInputTokens orders the results by the input_tokens field. +func ByInputTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldInputTokens, opts...).ToFunc() +} + +// ByOutputTokens orders the results by the output_tokens field. +func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() +} + +// ByCacheCreationTokens orders the results by the cache_creation_tokens field. 
+func ByCacheCreationTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreationTokens, opts...).ToFunc() +} + +// ByCacheReadTokens orders the results by the cache_read_tokens field. +func ByCacheReadTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheReadTokens, opts...).ToFunc() +} + +// ByCacheCreation5mTokens orders the results by the cache_creation_5m_tokens field. +func ByCacheCreation5mTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreation5mTokens, opts...).ToFunc() +} + +// ByCacheCreation1hTokens orders the results by the cache_creation_1h_tokens field. +func ByCacheCreation1hTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreation1hTokens, opts...).ToFunc() +} + +// ByInputCost orders the results by the input_cost field. +func ByInputCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldInputCost, opts...).ToFunc() +} + +// ByOutputCost orders the results by the output_cost field. +func ByOutputCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOutputCost, opts...).ToFunc() +} + +// ByCacheCreationCost orders the results by the cache_creation_cost field. +func ByCacheCreationCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreationCost, opts...).ToFunc() +} + +// ByCacheReadCost orders the results by the cache_read_cost field. +func ByCacheReadCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheReadCost, opts...).ToFunc() +} + +// ByTotalCost orders the results by the total_cost field. +func ByTotalCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalCost, opts...).ToFunc() +} + +// ByActualCost orders the results by the actual_cost field. 
+func ByActualCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldActualCost, opts...).ToFunc() +} + +// ByRateMultiplier orders the results by the rate_multiplier field. +func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc() +} + +// ByBillingType orders the results by the billing_type field. +func ByBillingType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingType, opts...).ToFunc() +} + +// ByStream orders the results by the stream field. +func ByStream(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStream, opts...).ToFunc() +} + +// ByDurationMs orders the results by the duration_ms field. +func ByDurationMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDurationMs, opts...).ToFunc() +} + +// ByFirstTokenMs orders the results by the first_token_ms field. +func ByFirstTokenMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFirstTokenMs, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAPIKeyField orders the results by api_key field. +func ByAPIKeyField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAPIKeyStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAccountField orders the results by account field. 
+func ByAccountField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAccountStep(), sql.OrderByField(field, opts...)) + } +} + +// ByGroupField orders the results by group field. +func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) + } +} + +// BySubscriptionField orders the results by subscription field. +func BySubscriptionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSubscriptionStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newAPIKeyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(APIKeyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, APIKeyTable, APIKeyColumn), + ) +} +func newAccountStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AccountInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, AccountTable, AccountColumn), + ) +} +func newGroupStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(GroupInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) +} +func newSubscriptionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SubscriptionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SubscriptionTable, SubscriptionColumn), + ) +} diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go new file mode 100644 index 00000000..9c260433 --- /dev/null +++ b/backend/ent/usagelog/where.go 
@@ -0,0 +1,1271 @@ +// Code generated by ent, DO NOT EDIT. + +package usagelog + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldID, id)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUserID, v)) +} + +// APIKeyID applies equality check predicate on the "api_key_id" field. It's identical to APIKeyIDEQ. 
+func APIKeyID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAPIKeyID, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. +func AccountID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAccountID, v)) +} + +// RequestID applies equality check predicate on the "request_id" field. It's identical to RequestIDEQ. +func RequestID(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestID, v)) +} + +// Model applies equality check predicate on the "model" field. It's identical to ModelEQ. +func Model(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) +} + +// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. +func GroupID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) +} + +// SubscriptionID applies equality check predicate on the "subscription_id" field. It's identical to SubscriptionIDEQ. +func SubscriptionID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldSubscriptionID, v)) +} + +// InputTokens applies equality check predicate on the "input_tokens" field. It's identical to InputTokensEQ. +func InputTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputTokens, v)) +} + +// OutputTokens applies equality check predicate on the "output_tokens" field. It's identical to OutputTokensEQ. +func OutputTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputTokens, v)) +} + +// CacheCreationTokens applies equality check predicate on the "cache_creation_tokens" field. It's identical to CacheCreationTokensEQ. +func CacheCreationTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationTokens, v)) +} + +// CacheReadTokens applies equality check predicate on the "cache_read_tokens" field. 
It's identical to CacheReadTokensEQ. +func CacheReadTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadTokens, v)) +} + +// CacheCreation5mTokens applies equality check predicate on the "cache_creation_5m_tokens" field. It's identical to CacheCreation5mTokensEQ. +func CacheCreation5mTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation1hTokens applies equality check predicate on the "cache_creation_1h_tokens" field. It's identical to CacheCreation1hTokensEQ. +func CacheCreation1hTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation1hTokens, v)) +} + +// InputCost applies equality check predicate on the "input_cost" field. It's identical to InputCostEQ. +func InputCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputCost, v)) +} + +// OutputCost applies equality check predicate on the "output_cost" field. It's identical to OutputCostEQ. +func OutputCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputCost, v)) +} + +// CacheCreationCost applies equality check predicate on the "cache_creation_cost" field. It's identical to CacheCreationCostEQ. +func CacheCreationCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationCost, v)) +} + +// CacheReadCost applies equality check predicate on the "cache_read_cost" field. It's identical to CacheReadCostEQ. +func CacheReadCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadCost, v)) +} + +// TotalCost applies equality check predicate on the "total_cost" field. It's identical to TotalCostEQ. +func TotalCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldTotalCost, v)) +} + +// ActualCost applies equality check predicate on the "actual_cost" field. It's identical to ActualCostEQ. 
+func ActualCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldActualCost, v)) +} + +// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ. +func RateMultiplier(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v)) +} + +// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ. +func BillingType(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v)) +} + +// Stream applies equality check predicate on the "stream" field. It's identical to StreamEQ. +func Stream(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldStream, v)) +} + +// DurationMs applies equality check predicate on the "duration_ms" field. It's identical to DurationMsEQ. +func DurationMs(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldDurationMs, v)) +} + +// FirstTokenMs applies equality check predicate on the "first_token_ms" field. It's identical to FirstTokenMsEQ. +func FirstTokenMs(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldFirstTokenMs, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. 
+func UserIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUserID, vs...)) +} + +// APIKeyIDEQ applies the EQ predicate on the "api_key_id" field. +func APIKeyIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAPIKeyID, v)) +} + +// APIKeyIDNEQ applies the NEQ predicate on the "api_key_id" field. +func APIKeyIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldAPIKeyID, v)) +} + +// APIKeyIDIn applies the In predicate on the "api_key_id" field. +func APIKeyIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldAPIKeyID, vs...)) +} + +// APIKeyIDNotIn applies the NotIn predicate on the "api_key_id" field. +func APIKeyIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldAPIKeyID, vs...)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. +func AccountIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. +func AccountIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// RequestIDEQ applies the EQ predicate on the "request_id" field. 
+func RequestIDEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestID, v)) +} + +// RequestIDNEQ applies the NEQ predicate on the "request_id" field. +func RequestIDNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRequestID, v)) +} + +// RequestIDIn applies the In predicate on the "request_id" field. +func RequestIDIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRequestID, vs...)) +} + +// RequestIDNotIn applies the NotIn predicate on the "request_id" field. +func RequestIDNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRequestID, vs...)) +} + +// RequestIDGT applies the GT predicate on the "request_id" field. +func RequestIDGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRequestID, v)) +} + +// RequestIDGTE applies the GTE predicate on the "request_id" field. +func RequestIDGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRequestID, v)) +} + +// RequestIDLT applies the LT predicate on the "request_id" field. +func RequestIDLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRequestID, v)) +} + +// RequestIDLTE applies the LTE predicate on the "request_id" field. +func RequestIDLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRequestID, v)) +} + +// RequestIDContains applies the Contains predicate on the "request_id" field. +func RequestIDContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldRequestID, v)) +} + +// RequestIDHasPrefix applies the HasPrefix predicate on the "request_id" field. +func RequestIDHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestID, v)) +} + +// RequestIDHasSuffix applies the HasSuffix predicate on the "request_id" field. 
+func RequestIDHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestID, v)) +} + +// RequestIDEqualFold applies the EqualFold predicate on the "request_id" field. +func RequestIDEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldRequestID, v)) +} + +// RequestIDContainsFold applies the ContainsFold predicate on the "request_id" field. +func RequestIDContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldRequestID, v)) +} + +// ModelEQ applies the EQ predicate on the "model" field. +func ModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) +} + +// ModelNEQ applies the NEQ predicate on the "model" field. +func ModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldModel, v)) +} + +// ModelIn applies the In predicate on the "model" field. +func ModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldModel, vs...)) +} + +// ModelNotIn applies the NotIn predicate on the "model" field. +func ModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldModel, vs...)) +} + +// ModelGT applies the GT predicate on the "model" field. +func ModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldModel, v)) +} + +// ModelGTE applies the GTE predicate on the "model" field. +func ModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldModel, v)) +} + +// ModelLT applies the LT predicate on the "model" field. +func ModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldModel, v)) +} + +// ModelLTE applies the LTE predicate on the "model" field. +func ModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldModel, v)) +} + +// ModelContains applies the Contains predicate on the "model" field. 
+func ModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldModel, v)) +} + +// ModelHasPrefix applies the HasPrefix predicate on the "model" field. +func ModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldModel, v)) +} + +// ModelHasSuffix applies the HasSuffix predicate on the "model" field. +func ModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldModel, v)) +} + +// ModelEqualFold applies the EqualFold predicate on the "model" field. +func ModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldModel, v)) +} + +// ModelContainsFold applies the ContainsFold predicate on the "model" field. +func ModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) +} + +// GroupIDEQ applies the EQ predicate on the "group_id" field. +func GroupIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) +} + +// GroupIDNEQ applies the NEQ predicate on the "group_id" field. +func GroupIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldGroupID, v)) +} + +// GroupIDIn applies the In predicate on the "group_id" field. +func GroupIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldGroupID, vs...)) +} + +// GroupIDNotIn applies the NotIn predicate on the "group_id" field. +func GroupIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldGroupID, vs...)) +} + +// GroupIDIsNil applies the IsNil predicate on the "group_id" field. +func GroupIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldGroupID)) +} + +// GroupIDNotNil applies the NotNil predicate on the "group_id" field. 
+func GroupIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldGroupID)) +} + +// SubscriptionIDEQ applies the EQ predicate on the "subscription_id" field. +func SubscriptionIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldSubscriptionID, v)) +} + +// SubscriptionIDNEQ applies the NEQ predicate on the "subscription_id" field. +func SubscriptionIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldSubscriptionID, v)) +} + +// SubscriptionIDIn applies the In predicate on the "subscription_id" field. +func SubscriptionIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldSubscriptionID, vs...)) +} + +// SubscriptionIDNotIn applies the NotIn predicate on the "subscription_id" field. +func SubscriptionIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldSubscriptionID, vs...)) +} + +// SubscriptionIDIsNil applies the IsNil predicate on the "subscription_id" field. +func SubscriptionIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldSubscriptionID)) +} + +// SubscriptionIDNotNil applies the NotNil predicate on the "subscription_id" field. +func SubscriptionIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldSubscriptionID)) +} + +// InputTokensEQ applies the EQ predicate on the "input_tokens" field. +func InputTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputTokens, v)) +} + +// InputTokensNEQ applies the NEQ predicate on the "input_tokens" field. +func InputTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldInputTokens, v)) +} + +// InputTokensIn applies the In predicate on the "input_tokens" field. 
+func InputTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldInputTokens, vs...)) +} + +// InputTokensNotIn applies the NotIn predicate on the "input_tokens" field. +func InputTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldInputTokens, vs...)) +} + +// InputTokensGT applies the GT predicate on the "input_tokens" field. +func InputTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldInputTokens, v)) +} + +// InputTokensGTE applies the GTE predicate on the "input_tokens" field. +func InputTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldInputTokens, v)) +} + +// InputTokensLT applies the LT predicate on the "input_tokens" field. +func InputTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldInputTokens, v)) +} + +// InputTokensLTE applies the LTE predicate on the "input_tokens" field. +func InputTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldInputTokens, v)) +} + +// OutputTokensEQ applies the EQ predicate on the "output_tokens" field. +func OutputTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputTokens, v)) +} + +// OutputTokensNEQ applies the NEQ predicate on the "output_tokens" field. +func OutputTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldOutputTokens, v)) +} + +// OutputTokensIn applies the In predicate on the "output_tokens" field. +func OutputTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldOutputTokens, vs...)) +} + +// OutputTokensNotIn applies the NotIn predicate on the "output_tokens" field. +func OutputTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldOutputTokens, vs...)) +} + +// OutputTokensGT applies the GT predicate on the "output_tokens" field. 
+func OutputTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldOutputTokens, v)) +} + +// OutputTokensGTE applies the GTE predicate on the "output_tokens" field. +func OutputTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldOutputTokens, v)) +} + +// OutputTokensLT applies the LT predicate on the "output_tokens" field. +func OutputTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldOutputTokens, v)) +} + +// OutputTokensLTE applies the LTE predicate on the "output_tokens" field. +func OutputTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldOutputTokens, v)) +} + +// CacheCreationTokensEQ applies the EQ predicate on the "cache_creation_tokens" field. +func CacheCreationTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensNEQ applies the NEQ predicate on the "cache_creation_tokens" field. +func CacheCreationTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensIn applies the In predicate on the "cache_creation_tokens" field. +func CacheCreationTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreationTokens, vs...)) +} + +// CacheCreationTokensNotIn applies the NotIn predicate on the "cache_creation_tokens" field. +func CacheCreationTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreationTokens, vs...)) +} + +// CacheCreationTokensGT applies the GT predicate on the "cache_creation_tokens" field. +func CacheCreationTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensGTE applies the GTE predicate on the "cache_creation_tokens" field. 
+func CacheCreationTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensLT applies the LT predicate on the "cache_creation_tokens" field. +func CacheCreationTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensLTE applies the LTE predicate on the "cache_creation_tokens" field. +func CacheCreationTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreationTokens, v)) +} + +// CacheReadTokensEQ applies the EQ predicate on the "cache_read_tokens" field. +func CacheReadTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadTokens, v)) +} + +// CacheReadTokensNEQ applies the NEQ predicate on the "cache_read_tokens" field. +func CacheReadTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheReadTokens, v)) +} + +// CacheReadTokensIn applies the In predicate on the "cache_read_tokens" field. +func CacheReadTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheReadTokens, vs...)) +} + +// CacheReadTokensNotIn applies the NotIn predicate on the "cache_read_tokens" field. +func CacheReadTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheReadTokens, vs...)) +} + +// CacheReadTokensGT applies the GT predicate on the "cache_read_tokens" field. +func CacheReadTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheReadTokens, v)) +} + +// CacheReadTokensGTE applies the GTE predicate on the "cache_read_tokens" field. +func CacheReadTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheReadTokens, v)) +} + +// CacheReadTokensLT applies the LT predicate on the "cache_read_tokens" field. 
+func CacheReadTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheReadTokens, v)) +} + +// CacheReadTokensLTE applies the LTE predicate on the "cache_read_tokens" field. +func CacheReadTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheReadTokens, v)) +} + +// CacheCreation5mTokensEQ applies the EQ predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensNEQ applies the NEQ predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensIn applies the In predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreation5mTokens, vs...)) +} + +// CacheCreation5mTokensNotIn applies the NotIn predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreation5mTokens, vs...)) +} + +// CacheCreation5mTokensGT applies the GT predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensGTE applies the GTE predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensLT applies the LT predicate on the "cache_creation_5m_tokens" field. 
+func CacheCreation5mTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensLTE applies the LTE predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation1hTokensEQ applies the EQ predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensNEQ applies the NEQ predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensIn applies the In predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreation1hTokens, vs...)) +} + +// CacheCreation1hTokensNotIn applies the NotIn predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreation1hTokens, vs...)) +} + +// CacheCreation1hTokensGT applies the GT predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensGTE applies the GTE predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensLT applies the LT predicate on the "cache_creation_1h_tokens" field. 
+func CacheCreation1hTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensLTE applies the LTE predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreation1hTokens, v)) +} + +// InputCostEQ applies the EQ predicate on the "input_cost" field. +func InputCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputCost, v)) +} + +// InputCostNEQ applies the NEQ predicate on the "input_cost" field. +func InputCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldInputCost, v)) +} + +// InputCostIn applies the In predicate on the "input_cost" field. +func InputCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldInputCost, vs...)) +} + +// InputCostNotIn applies the NotIn predicate on the "input_cost" field. +func InputCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldInputCost, vs...)) +} + +// InputCostGT applies the GT predicate on the "input_cost" field. +func InputCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldInputCost, v)) +} + +// InputCostGTE applies the GTE predicate on the "input_cost" field. +func InputCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldInputCost, v)) +} + +// InputCostLT applies the LT predicate on the "input_cost" field. +func InputCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldInputCost, v)) +} + +// InputCostLTE applies the LTE predicate on the "input_cost" field. +func InputCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldInputCost, v)) +} + +// OutputCostEQ applies the EQ predicate on the "output_cost" field. 
+func OutputCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputCost, v)) +} + +// OutputCostNEQ applies the NEQ predicate on the "output_cost" field. +func OutputCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldOutputCost, v)) +} + +// OutputCostIn applies the In predicate on the "output_cost" field. +func OutputCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldOutputCost, vs...)) +} + +// OutputCostNotIn applies the NotIn predicate on the "output_cost" field. +func OutputCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldOutputCost, vs...)) +} + +// OutputCostGT applies the GT predicate on the "output_cost" field. +func OutputCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldOutputCost, v)) +} + +// OutputCostGTE applies the GTE predicate on the "output_cost" field. +func OutputCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldOutputCost, v)) +} + +// OutputCostLT applies the LT predicate on the "output_cost" field. +func OutputCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldOutputCost, v)) +} + +// OutputCostLTE applies the LTE predicate on the "output_cost" field. +func OutputCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldOutputCost, v)) +} + +// CacheCreationCostEQ applies the EQ predicate on the "cache_creation_cost" field. +func CacheCreationCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationCost, v)) +} + +// CacheCreationCostNEQ applies the NEQ predicate on the "cache_creation_cost" field. +func CacheCreationCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreationCost, v)) +} + +// CacheCreationCostIn applies the In predicate on the "cache_creation_cost" field. 
+func CacheCreationCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreationCost, vs...)) +} + +// CacheCreationCostNotIn applies the NotIn predicate on the "cache_creation_cost" field. +func CacheCreationCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreationCost, vs...)) +} + +// CacheCreationCostGT applies the GT predicate on the "cache_creation_cost" field. +func CacheCreationCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreationCost, v)) +} + +// CacheCreationCostGTE applies the GTE predicate on the "cache_creation_cost" field. +func CacheCreationCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreationCost, v)) +} + +// CacheCreationCostLT applies the LT predicate on the "cache_creation_cost" field. +func CacheCreationCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreationCost, v)) +} + +// CacheCreationCostLTE applies the LTE predicate on the "cache_creation_cost" field. +func CacheCreationCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreationCost, v)) +} + +// CacheReadCostEQ applies the EQ predicate on the "cache_read_cost" field. +func CacheReadCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadCost, v)) +} + +// CacheReadCostNEQ applies the NEQ predicate on the "cache_read_cost" field. +func CacheReadCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheReadCost, v)) +} + +// CacheReadCostIn applies the In predicate on the "cache_read_cost" field. +func CacheReadCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheReadCost, vs...)) +} + +// CacheReadCostNotIn applies the NotIn predicate on the "cache_read_cost" field. 
+func CacheReadCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheReadCost, vs...)) +} + +// CacheReadCostGT applies the GT predicate on the "cache_read_cost" field. +func CacheReadCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheReadCost, v)) +} + +// CacheReadCostGTE applies the GTE predicate on the "cache_read_cost" field. +func CacheReadCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheReadCost, v)) +} + +// CacheReadCostLT applies the LT predicate on the "cache_read_cost" field. +func CacheReadCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheReadCost, v)) +} + +// CacheReadCostLTE applies the LTE predicate on the "cache_read_cost" field. +func CacheReadCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheReadCost, v)) +} + +// TotalCostEQ applies the EQ predicate on the "total_cost" field. +func TotalCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldTotalCost, v)) +} + +// TotalCostNEQ applies the NEQ predicate on the "total_cost" field. +func TotalCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldTotalCost, v)) +} + +// TotalCostIn applies the In predicate on the "total_cost" field. +func TotalCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldTotalCost, vs...)) +} + +// TotalCostNotIn applies the NotIn predicate on the "total_cost" field. +func TotalCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldTotalCost, vs...)) +} + +// TotalCostGT applies the GT predicate on the "total_cost" field. +func TotalCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldTotalCost, v)) +} + +// TotalCostGTE applies the GTE predicate on the "total_cost" field. 
+func TotalCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldTotalCost, v)) +} + +// TotalCostLT applies the LT predicate on the "total_cost" field. +func TotalCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldTotalCost, v)) +} + +// TotalCostLTE applies the LTE predicate on the "total_cost" field. +func TotalCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldTotalCost, v)) +} + +// ActualCostEQ applies the EQ predicate on the "actual_cost" field. +func ActualCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldActualCost, v)) +} + +// ActualCostNEQ applies the NEQ predicate on the "actual_cost" field. +func ActualCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldActualCost, v)) +} + +// ActualCostIn applies the In predicate on the "actual_cost" field. +func ActualCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldActualCost, vs...)) +} + +// ActualCostNotIn applies the NotIn predicate on the "actual_cost" field. +func ActualCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldActualCost, vs...)) +} + +// ActualCostGT applies the GT predicate on the "actual_cost" field. +func ActualCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldActualCost, v)) +} + +// ActualCostGTE applies the GTE predicate on the "actual_cost" field. +func ActualCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldActualCost, v)) +} + +// ActualCostLT applies the LT predicate on the "actual_cost" field. +func ActualCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldActualCost, v)) +} + +// ActualCostLTE applies the LTE predicate on the "actual_cost" field. 
+func ActualCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldActualCost, v)) +} + +// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field. +func RateMultiplierEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v)) +} + +// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field. +func RateMultiplierNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRateMultiplier, v)) +} + +// RateMultiplierIn applies the In predicate on the "rate_multiplier" field. +func RateMultiplierIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRateMultiplier, vs...)) +} + +// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field. +func RateMultiplierNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRateMultiplier, vs...)) +} + +// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field. +func RateMultiplierGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRateMultiplier, v)) +} + +// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field. +func RateMultiplierGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRateMultiplier, v)) +} + +// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field. +func RateMultiplierLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRateMultiplier, v)) +} + +// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field. +func RateMultiplierLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v)) +} + +// BillingTypeEQ applies the EQ predicate on the "billing_type" field. 
+func BillingTypeEQ(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v)) +} + +// BillingTypeNEQ applies the NEQ predicate on the "billing_type" field. +func BillingTypeNEQ(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingType, v)) +} + +// BillingTypeIn applies the In predicate on the "billing_type" field. +func BillingTypeIn(vs ...int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingType, vs...)) +} + +// BillingTypeNotIn applies the NotIn predicate on the "billing_type" field. +func BillingTypeNotIn(vs ...int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingType, vs...)) +} + +// BillingTypeGT applies the GT predicate on the "billing_type" field. +func BillingTypeGT(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingType, v)) +} + +// BillingTypeGTE applies the GTE predicate on the "billing_type" field. +func BillingTypeGTE(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingType, v)) +} + +// BillingTypeLT applies the LT predicate on the "billing_type" field. +func BillingTypeLT(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingType, v)) +} + +// BillingTypeLTE applies the LTE predicate on the "billing_type" field. +func BillingTypeLTE(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingType, v)) +} + +// StreamEQ applies the EQ predicate on the "stream" field. +func StreamEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldStream, v)) +} + +// StreamNEQ applies the NEQ predicate on the "stream" field. +func StreamNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldStream, v)) +} + +// DurationMsEQ applies the EQ predicate on the "duration_ms" field. 
+func DurationMsEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldDurationMs, v)) +} + +// DurationMsNEQ applies the NEQ predicate on the "duration_ms" field. +func DurationMsNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldDurationMs, v)) +} + +// DurationMsIn applies the In predicate on the "duration_ms" field. +func DurationMsIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldDurationMs, vs...)) +} + +// DurationMsNotIn applies the NotIn predicate on the "duration_ms" field. +func DurationMsNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldDurationMs, vs...)) +} + +// DurationMsGT applies the GT predicate on the "duration_ms" field. +func DurationMsGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldDurationMs, v)) +} + +// DurationMsGTE applies the GTE predicate on the "duration_ms" field. +func DurationMsGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldDurationMs, v)) +} + +// DurationMsLT applies the LT predicate on the "duration_ms" field. +func DurationMsLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldDurationMs, v)) +} + +// DurationMsLTE applies the LTE predicate on the "duration_ms" field. +func DurationMsLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldDurationMs, v)) +} + +// DurationMsIsNil applies the IsNil predicate on the "duration_ms" field. +func DurationMsIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldDurationMs)) +} + +// DurationMsNotNil applies the NotNil predicate on the "duration_ms" field. +func DurationMsNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldDurationMs)) +} + +// FirstTokenMsEQ applies the EQ predicate on the "first_token_ms" field. 
+func FirstTokenMsEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldFirstTokenMs, v)) +} + +// FirstTokenMsNEQ applies the NEQ predicate on the "first_token_ms" field. +func FirstTokenMsNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldFirstTokenMs, v)) +} + +// FirstTokenMsIn applies the In predicate on the "first_token_ms" field. +func FirstTokenMsIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldFirstTokenMs, vs...)) +} + +// FirstTokenMsNotIn applies the NotIn predicate on the "first_token_ms" field. +func FirstTokenMsNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldFirstTokenMs, vs...)) +} + +// FirstTokenMsGT applies the GT predicate on the "first_token_ms" field. +func FirstTokenMsGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldFirstTokenMs, v)) +} + +// FirstTokenMsGTE applies the GTE predicate on the "first_token_ms" field. +func FirstTokenMsGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldFirstTokenMs, v)) +} + +// FirstTokenMsLT applies the LT predicate on the "first_token_ms" field. +func FirstTokenMsLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldFirstTokenMs, v)) +} + +// FirstTokenMsLTE applies the LTE predicate on the "first_token_ms" field. +func FirstTokenMsLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldFirstTokenMs, v)) +} + +// FirstTokenMsIsNil applies the IsNil predicate on the "first_token_ms" field. +func FirstTokenMsIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldFirstTokenMs)) +} + +// FirstTokenMsNotNil applies the NotNil predicate on the "first_token_ms" field. +func FirstTokenMsNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldFirstTokenMs)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. 
+func CreatedAtEQ(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCreatedAt, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). 
+func HasUserWith(preds ...predicate.User) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAPIKey applies the HasEdge predicate on the "api_key" edge. +func HasAPIKey() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, APIKeyTable, APIKeyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAPIKeyWith applies the HasEdge predicate on the "api_key" edge with a given conditions (other predicates). +func HasAPIKeyWith(preds ...predicate.ApiKey) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newAPIKeyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAccount applies the HasEdge predicate on the "account" edge. +func HasAccount() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, AccountTable, AccountColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAccountWith applies the HasEdge predicate on the "account" edge with a given conditions (other predicates). +func HasAccountWith(preds ...predicate.Account) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newAccountStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasGroup applies the HasEdge predicate on the "group" edge. 
+func HasGroup() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasGroupWith applies the HasEdge predicate on the "group" edge with a given conditions (other predicates). +func HasGroupWith(preds ...predicate.Group) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newGroupStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasSubscription applies the HasEdge predicate on the "subscription" edge. +func HasSubscription() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SubscriptionTable, SubscriptionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasSubscriptionWith applies the HasEdge predicate on the "subscription" edge with a given conditions (other predicates). +func HasSubscriptionWith(preds ...predicate.UserSubscription) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newSubscriptionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.NotPredicates(p)) +} diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go new file mode 100644 index 00000000..bcba64b1 --- /dev/null +++ b/backend/ent/usagelog_create.go @@ -0,0 +1,2431 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogCreate is the builder for creating a UsageLog entity. +type UsageLogCreate struct { + config + mutation *UsageLogMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetUserID sets the "user_id" field. +func (_c *UsageLogCreate) SetUserID(v int64) *UsageLogCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetAPIKeyID sets the "api_key_id" field. +func (_c *UsageLogCreate) SetAPIKeyID(v int64) *UsageLogCreate { + _c.mutation.SetAPIKeyID(v) + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *UsageLogCreate) SetAccountID(v int64) *UsageLogCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetRequestID sets the "request_id" field. +func (_c *UsageLogCreate) SetRequestID(v string) *UsageLogCreate { + _c.mutation.SetRequestID(v) + return _c +} + +// SetModel sets the "model" field. +func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { + _c.mutation.SetModel(v) + return _c +} + +// SetGroupID sets the "group_id" field. +func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { + _c.mutation.SetGroupID(v) + return _c +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. 
+func (_c *UsageLogCreate) SetNillableGroupID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetGroupID(*v) + } + return _c +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_c *UsageLogCreate) SetSubscriptionID(v int64) *UsageLogCreate { + _c.mutation.SetSubscriptionID(v) + return _c +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableSubscriptionID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetSubscriptionID(*v) + } + return _c +} + +// SetInputTokens sets the "input_tokens" field. +func (_c *UsageLogCreate) SetInputTokens(v int) *UsageLogCreate { + _c.mutation.SetInputTokens(v) + return _c +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableInputTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetInputTokens(*v) + } + return _c +} + +// SetOutputTokens sets the "output_tokens" field. +func (_c *UsageLogCreate) SetOutputTokens(v int) *UsageLogCreate { + _c.mutation.SetOutputTokens(v) + return _c +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableOutputTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetOutputTokens(*v) + } + return _c +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_c *UsageLogCreate) SetCacheCreationTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreationTokens(v) + return _c +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreationTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreationTokens(*v) + } + return _c +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. 
+func (_c *UsageLogCreate) SetCacheReadTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheReadTokens(v) + return _c +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheReadTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheReadTokens(*v) + } + return _c +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (_c *UsageLogCreate) SetCacheCreation5mTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreation5mTokens(v) + return _c +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreation5mTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreation5mTokens(*v) + } + return _c +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_c *UsageLogCreate) SetCacheCreation1hTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreation1hTokens(v) + return _c +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreation1hTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreation1hTokens(*v) + } + return _c +} + +// SetInputCost sets the "input_cost" field. +func (_c *UsageLogCreate) SetInputCost(v float64) *UsageLogCreate { + _c.mutation.SetInputCost(v) + return _c +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableInputCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetInputCost(*v) + } + return _c +} + +// SetOutputCost sets the "output_cost" field. +func (_c *UsageLogCreate) SetOutputCost(v float64) *UsageLogCreate { + _c.mutation.SetOutputCost(v) + return _c +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. 
+func (_c *UsageLogCreate) SetNillableOutputCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetOutputCost(*v) + } + return _c +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_c *UsageLogCreate) SetCacheCreationCost(v float64) *UsageLogCreate { + _c.mutation.SetCacheCreationCost(v) + return _c +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreationCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetCacheCreationCost(*v) + } + return _c +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_c *UsageLogCreate) SetCacheReadCost(v float64) *UsageLogCreate { + _c.mutation.SetCacheReadCost(v) + return _c +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheReadCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetCacheReadCost(*v) + } + return _c +} + +// SetTotalCost sets the "total_cost" field. +func (_c *UsageLogCreate) SetTotalCost(v float64) *UsageLogCreate { + _c.mutation.SetTotalCost(v) + return _c +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableTotalCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetTotalCost(*v) + } + return _c +} + +// SetActualCost sets the "actual_cost" field. +func (_c *UsageLogCreate) SetActualCost(v float64) *UsageLogCreate { + _c.mutation.SetActualCost(v) + return _c +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableActualCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetActualCost(*v) + } + return _c +} + +// SetRateMultiplier sets the "rate_multiplier" field. 
+func (_c *UsageLogCreate) SetRateMultiplier(v float64) *UsageLogCreate { + _c.mutation.SetRateMultiplier(v) + return _c +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate { + if v != nil { + _c.SetRateMultiplier(*v) + } + return _c +} + +// SetBillingType sets the "billing_type" field. +func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate { + _c.mutation.SetBillingType(v) + return _c +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingType(v *int8) *UsageLogCreate { + if v != nil { + _c.SetBillingType(*v) + } + return _c +} + +// SetStream sets the "stream" field. +func (_c *UsageLogCreate) SetStream(v bool) *UsageLogCreate { + _c.mutation.SetStream(v) + return _c +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableStream(v *bool) *UsageLogCreate { + if v != nil { + _c.SetStream(*v) + } + return _c +} + +// SetDurationMs sets the "duration_ms" field. +func (_c *UsageLogCreate) SetDurationMs(v int) *UsageLogCreate { + _c.mutation.SetDurationMs(v) + return _c +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableDurationMs(v *int) *UsageLogCreate { + if v != nil { + _c.SetDurationMs(*v) + } + return _c +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (_c *UsageLogCreate) SetFirstTokenMs(v int) *UsageLogCreate { + _c.mutation.SetFirstTokenMs(v) + return _c +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableFirstTokenMs(v *int) *UsageLogCreate { + if v != nil { + _c.SetFirstTokenMs(*v) + } + return _c +} + +// SetCreatedAt sets the "created_at" field. 
+func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCreatedAt(v *time.Time) *UsageLogCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *UsageLogCreate) SetUser(v *User) *UsageLogCreate { + return _c.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_c *UsageLogCreate) SetAPIKey(v *ApiKey) *UsageLogCreate { + return _c.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. +func (_c *UsageLogCreate) SetAccount(v *Account) *UsageLogCreate { + return _c.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_c *UsageLogCreate) SetGroup(v *Group) *UsageLogCreate { + return _c.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_c *UsageLogCreate) SetSubscription(v *UserSubscription) *UsageLogCreate { + return _c.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_c *UsageLogCreate) Mutation() *UsageLogMutation { + return _c.mutation +} + +// Save creates the UsageLog in the database. +func (_c *UsageLogCreate) Save(ctx context.Context) (*UsageLog, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UsageLogCreate) SaveX(ctx context.Context) *UsageLog { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UsageLogCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_c *UsageLogCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *UsageLogCreate) defaults() { + if _, ok := _c.mutation.InputTokens(); !ok { + v := usagelog.DefaultInputTokens + _c.mutation.SetInputTokens(v) + } + if _, ok := _c.mutation.OutputTokens(); !ok { + v := usagelog.DefaultOutputTokens + _c.mutation.SetOutputTokens(v) + } + if _, ok := _c.mutation.CacheCreationTokens(); !ok { + v := usagelog.DefaultCacheCreationTokens + _c.mutation.SetCacheCreationTokens(v) + } + if _, ok := _c.mutation.CacheReadTokens(); !ok { + v := usagelog.DefaultCacheReadTokens + _c.mutation.SetCacheReadTokens(v) + } + if _, ok := _c.mutation.CacheCreation5mTokens(); !ok { + v := usagelog.DefaultCacheCreation5mTokens + _c.mutation.SetCacheCreation5mTokens(v) + } + if _, ok := _c.mutation.CacheCreation1hTokens(); !ok { + v := usagelog.DefaultCacheCreation1hTokens + _c.mutation.SetCacheCreation1hTokens(v) + } + if _, ok := _c.mutation.InputCost(); !ok { + v := usagelog.DefaultInputCost + _c.mutation.SetInputCost(v) + } + if _, ok := _c.mutation.OutputCost(); !ok { + v := usagelog.DefaultOutputCost + _c.mutation.SetOutputCost(v) + } + if _, ok := _c.mutation.CacheCreationCost(); !ok { + v := usagelog.DefaultCacheCreationCost + _c.mutation.SetCacheCreationCost(v) + } + if _, ok := _c.mutation.CacheReadCost(); !ok { + v := usagelog.DefaultCacheReadCost + _c.mutation.SetCacheReadCost(v) + } + if _, ok := _c.mutation.TotalCost(); !ok { + v := usagelog.DefaultTotalCost + _c.mutation.SetTotalCost(v) + } + if _, ok := _c.mutation.ActualCost(); !ok { + v := usagelog.DefaultActualCost + _c.mutation.SetActualCost(v) + } + if _, ok := _c.mutation.RateMultiplier(); !ok { + v := usagelog.DefaultRateMultiplier + _c.mutation.SetRateMultiplier(v) + } + if _, ok := _c.mutation.BillingType(); !ok { + v := usagelog.DefaultBillingType + _c.mutation.SetBillingType(v) + } + if _, ok 
:= _c.mutation.Stream(); !ok { + v := usagelog.DefaultStream + _c.mutation.SetStream(v) + } + if _, ok := _c.mutation.CreatedAt(); !ok { + v := usagelog.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UsageLogCreate) check() error { + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UsageLog.user_id"`)} + } + if _, ok := _c.mutation.APIKeyID(); !ok { + return &ValidationError{Name: "api_key_id", err: errors.New(`ent: missing required field "UsageLog.api_key_id"`)} + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "UsageLog.account_id"`)} + } + if _, ok := _c.mutation.RequestID(); !ok { + return &ValidationError{Name: "request_id", err: errors.New(`ent: missing required field "UsageLog.request_id"`)} + } + if v, ok := _c.mutation.RequestID(); ok { + if err := usagelog.RequestIDValidator(v); err != nil { + return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)} + } + } + if _, ok := _c.mutation.Model(); !ok { + return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "UsageLog.model"`)} + } + if v, ok := _c.mutation.Model(); ok { + if err := usagelog.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} + } + } + if _, ok := _c.mutation.InputTokens(); !ok { + return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} + } + if _, ok := _c.mutation.OutputTokens(); !ok { + return &ValidationError{Name: "output_tokens", err: errors.New(`ent: missing required field "UsageLog.output_tokens"`)} + } + if _, ok := _c.mutation.CacheCreationTokens(); !ok { + return 
&ValidationError{Name: "cache_creation_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_tokens"`)} + } + if _, ok := _c.mutation.CacheReadTokens(); !ok { + return &ValidationError{Name: "cache_read_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_read_tokens"`)} + } + if _, ok := _c.mutation.CacheCreation5mTokens(); !ok { + return &ValidationError{Name: "cache_creation_5m_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_5m_tokens"`)} + } + if _, ok := _c.mutation.CacheCreation1hTokens(); !ok { + return &ValidationError{Name: "cache_creation_1h_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_1h_tokens"`)} + } + if _, ok := _c.mutation.InputCost(); !ok { + return &ValidationError{Name: "input_cost", err: errors.New(`ent: missing required field "UsageLog.input_cost"`)} + } + if _, ok := _c.mutation.OutputCost(); !ok { + return &ValidationError{Name: "output_cost", err: errors.New(`ent: missing required field "UsageLog.output_cost"`)} + } + if _, ok := _c.mutation.CacheCreationCost(); !ok { + return &ValidationError{Name: "cache_creation_cost", err: errors.New(`ent: missing required field "UsageLog.cache_creation_cost"`)} + } + if _, ok := _c.mutation.CacheReadCost(); !ok { + return &ValidationError{Name: "cache_read_cost", err: errors.New(`ent: missing required field "UsageLog.cache_read_cost"`)} + } + if _, ok := _c.mutation.TotalCost(); !ok { + return &ValidationError{Name: "total_cost", err: errors.New(`ent: missing required field "UsageLog.total_cost"`)} + } + if _, ok := _c.mutation.ActualCost(); !ok { + return &ValidationError{Name: "actual_cost", err: errors.New(`ent: missing required field "UsageLog.actual_cost"`)} + } + if _, ok := _c.mutation.RateMultiplier(); !ok { + return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "UsageLog.rate_multiplier"`)} + } + if _, ok := _c.mutation.BillingType(); !ok { + return 
&ValidationError{Name: "billing_type", err: errors.New(`ent: missing required field "UsageLog.billing_type"`)} + } + if _, ok := _c.mutation.Stream(); !ok { + return &ValidationError{Name: "stream", err: errors.New(`ent: missing required field "UsageLog.stream"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UsageLog.user"`)} + } + if len(_c.mutation.APIKeyIDs()) == 0 { + return &ValidationError{Name: "api_key", err: errors.New(`ent: missing required edge "UsageLog.api_key"`)} + } + if len(_c.mutation.AccountIDs()) == 0 { + return &ValidationError{Name: "account", err: errors.New(`ent: missing required edge "UsageLog.account"`)} + } + return nil +} + +func (_c *UsageLogCreate) sqlSave(ctx context.Context) (*UsageLog, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { + var ( + _node = &UsageLog{config: _c.config} + _spec = sqlgraph.NewCreateSpec(usagelog.Table, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.RequestID(); ok { + _spec.SetField(usagelog.FieldRequestID, field.TypeString, value) + _node.RequestID = value + } + if value, ok := _c.mutation.Model(); ok { + _spec.SetField(usagelog.FieldModel, field.TypeString, value) + _node.Model = value + } + if value, ok := _c.mutation.InputTokens(); ok { + 
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) + _node.InputTokens = value + } + if value, ok := _c.mutation.OutputTokens(); ok { + _spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value) + _node.OutputTokens = value + } + if value, ok := _c.mutation.CacheCreationTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + _node.CacheCreationTokens = value + } + if value, ok := _c.mutation.CacheReadTokens(); ok { + _spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + _node.CacheReadTokens = value + } + if value, ok := _c.mutation.CacheCreation5mTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + _node.CacheCreation5mTokens = value + } + if value, ok := _c.mutation.CacheCreation1hTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + _node.CacheCreation1hTokens = value + } + if value, ok := _c.mutation.InputCost(); ok { + _spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value) + _node.InputCost = value + } + if value, ok := _c.mutation.OutputCost(); ok { + _spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value) + _node.OutputCost = value + } + if value, ok := _c.mutation.CacheCreationCost(); ok { + _spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + _node.CacheCreationCost = value + } + if value, ok := _c.mutation.CacheReadCost(); ok { + _spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + _node.CacheReadCost = value + } + if value, ok := _c.mutation.TotalCost(); ok { + _spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value) + _node.TotalCost = value + } + if value, ok := _c.mutation.ActualCost(); ok { + _spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value) + _node.ActualCost = value + } + if value, ok := _c.mutation.RateMultiplier(); ok { + _spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + 
_node.RateMultiplier = value + } + if value, ok := _c.mutation.BillingType(); ok { + _spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value) + _node.BillingType = value + } + if value, ok := _c.mutation.Stream(); ok { + _spec.SetField(usagelog.FieldStream, field.TypeBool, value) + _node.Stream = value + } + if value, ok := _c.mutation.DurationMs(); ok { + _spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value) + _node.DurationMs = &value + } + if value, ok := _c.mutation.FirstTokenMs(); ok { + _spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + _node.FirstTokenMs = &value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.APIKeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.APIKeyID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AccountIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: 
sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.AccountID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.GroupID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.SubscriptionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.SubscriptionID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UsageLog.Create(). +// SetUserID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UsageLogUpsert) { +// SetUserID(v+v). +// }). +// Exec(ctx) +func (_c *UsageLogCreate) OnConflict(opts ...sql.ConflictOption) *UsageLogUpsertOne { + _c.conflict = opts + return &UsageLogUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. 
Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UsageLogCreate) OnConflictColumns(columns ...string) *UsageLogUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UsageLogUpsertOne{ + create: _c, + } +} + +type ( + // UsageLogUpsertOne is the builder for "upsert"-ing + // one UsageLog node. + UsageLogUpsertOne struct { + create *UsageLogCreate + } + + // UsageLogUpsert is the "OnConflict" setter. + UsageLogUpsert struct { + *sql.UpdateSet + } +) + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsert) SetUserID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUserID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUserID) + return u +} + +// SetAPIKeyID sets the "api_key_id" field. +func (u *UsageLogUpsert) SetAPIKeyID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldAPIKeyID, v) + return u +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateAPIKeyID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldAPIKeyID) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsert) SetAccountID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateAccountID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldAccountID) + return u +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsert) SetRequestID(v string) *UsageLogUpsert { + u.Set(usagelog.FieldRequestID, v) + return u +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateRequestID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRequestID) + return u +} + +// SetModel sets the "model" field. +func (u *UsageLogUpsert) SetModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldModel, v) + return u +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldModel) + return u +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldGroupID, v) + return u +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateGroupID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldGroupID) + return u +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsert) ClearGroupID() *UsageLogUpsert { + u.SetNull(usagelog.FieldGroupID) + return u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsert) SetSubscriptionID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldSubscriptionID, v) + return u +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateSubscriptionID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldSubscriptionID) + return u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsert) ClearSubscriptionID() *UsageLogUpsert { + u.SetNull(usagelog.FieldSubscriptionID) + return u +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsert) SetInputTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldInputTokens, v) + return u +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateInputTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldInputTokens) + return u +} + +// AddInputTokens adds v to the "input_tokens" field. +func (u *UsageLogUpsert) AddInputTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldInputTokens, v) + return u +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsert) SetOutputTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldOutputTokens, v) + return u +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateOutputTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldOutputTokens) + return u +} + +// AddOutputTokens adds v to the "output_tokens" field. +func (u *UsageLogUpsert) AddOutputTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldOutputTokens, v) + return u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsert) SetCacheCreationTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreationTokens, v) + return u +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreationTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreationTokens) + return u +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsert) AddCacheCreationTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreationTokens, v) + return u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (u *UsageLogUpsert) SetCacheReadTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheReadTokens, v) + return u +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateCacheReadTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheReadTokens) + return u +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsert) AddCacheReadTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheReadTokens, v) + return u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsert) SetCacheCreation5mTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreation5mTokens, v) + return u +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreation5mTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreation5mTokens) + return u +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsert) AddCacheCreation5mTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreation5mTokens, v) + return u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsert) SetCacheCreation1hTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreation1hTokens, v) + return u +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreation1hTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreation1hTokens) + return u +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsert) AddCacheCreation1hTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreation1hTokens, v) + return u +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsert) SetInputCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldInputCost, v) + return u +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateInputCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldInputCost) + return u +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsert) AddInputCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldInputCost, v) + return u +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsert) SetOutputCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldOutputCost, v) + return u +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateOutputCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldOutputCost) + return u +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsert) AddOutputCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldOutputCost, v) + return u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (u *UsageLogUpsert) SetCacheCreationCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreationCost, v) + return u +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreationCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreationCost) + return u +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsert) AddCacheCreationCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreationCost, v) + return u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsert) SetCacheReadCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldCacheReadCost, v) + return u +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheReadCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheReadCost) + return u +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. 
+func (u *UsageLogUpsert) AddCacheReadCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldCacheReadCost, v) + return u +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsert) SetTotalCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldTotalCost, v) + return u +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateTotalCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldTotalCost) + return u +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsert) AddTotalCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldTotalCost, v) + return u +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsert) SetActualCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldActualCost, v) + return u +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateActualCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldActualCost) + return u +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsert) AddActualCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldActualCost, v) + return u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsert) SetRateMultiplier(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldRateMultiplier, v) + return u +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateRateMultiplier() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRateMultiplier) + return u +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. +func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldRateMultiplier, v) + return u +} + +// SetBillingType sets the "billing_type" field. 
+func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert { + u.Set(usagelog.FieldBillingType, v) + return u +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateBillingType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingType) + return u +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsert) AddBillingType(v int8) *UsageLogUpsert { + u.Add(usagelog.FieldBillingType, v) + return u +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsert) SetStream(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldStream, v) + return u +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateStream() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldStream) + return u +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsert) SetDurationMs(v int) *UsageLogUpsert { + u.Set(usagelog.FieldDurationMs, v) + return u +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateDurationMs() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldDurationMs) + return u +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsert) AddDurationMs(v int) *UsageLogUpsert { + u.Add(usagelog.FieldDurationMs, v) + return u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsert) ClearDurationMs() *UsageLogUpsert { + u.SetNull(usagelog.FieldDurationMs) + return u +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsert) SetFirstTokenMs(v int) *UsageLogUpsert { + u.Set(usagelog.FieldFirstTokenMs, v) + return u +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateFirstTokenMs() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldFirstTokenMs) + return u +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsert) AddFirstTokenMs(v int) *UsageLogUpsert { + u.Add(usagelog.FieldFirstTokenMs, v) + return u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsert) ClearFirstTokenMs() *UsageLogUpsert { + u.SetNull(usagelog.FieldFirstTokenMs) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UsageLogUpsertOne) UpdateNewValues() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(usagelog.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageLogUpsertOne) Ignore() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageLogUpsertOne) DoNothing() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageLogCreate.OnConflict +// documentation for more info. 
+func (u *UsageLogUpsertOne) Update(set func(*UsageLogUpsert)) *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageLogUpsert{UpdateSet: update}) + })) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsertOne) SetUserID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUserID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUserID() + }) +} + +// SetAPIKeyID sets the "api_key_id" field. +func (u *UsageLogUpsertOne) SetAPIKeyID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetAPIKeyID(v) + }) +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateAPIKeyID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAPIKeyID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsertOne) SetAccountID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateAccountID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAccountID() + }) +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsertOne) SetRequestID(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestID(v) + }) +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRequestID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestID() + }) +} + +// SetModel sets the "model" field. 
+func (u *UsageLogUpsertOne) SetModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModel() + }) +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetGroupID(v) + }) +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateGroupID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateGroupID() + }) +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsertOne) ClearGroupID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearGroupID() + }) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsertOne) SetSubscriptionID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetSubscriptionID(v) + }) +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateSubscriptionID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateSubscriptionID() + }) +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsertOne) ClearSubscriptionID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearSubscriptionID() + }) +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsertOne) SetInputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputTokens(v) + }) +} + +// AddInputTokens adds v to the "input_tokens" field. 
+func (u *UsageLogUpsertOne) AddInputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputTokens(v) + }) +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateInputTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputTokens() + }) +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsertOne) SetOutputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputTokens(v) + }) +} + +// AddOutputTokens adds v to the "output_tokens" field. +func (u *UsageLogUpsertOne) AddOutputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputTokens(v) + }) +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateOutputTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputTokens() + }) +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreationTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationTokens(v) + }) +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsertOne) AddCacheCreationTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationTokens(v) + }) +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreationTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationTokens() + }) +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. 
+func (u *UsageLogUpsertOne) SetCacheReadTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadTokens(v) + }) +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsertOne) AddCacheReadTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadTokens(v) + }) +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheReadTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadTokens() + }) +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreation5mTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation5mTokens(v) + }) +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertOne) AddCacheCreation5mTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation5mTokens(v) + }) +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreation5mTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation5mTokens() + }) +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreation1hTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation1hTokens(v) + }) +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. 
+func (u *UsageLogUpsertOne) AddCacheCreation1hTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation1hTokens(v) + }) +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreation1hTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation1hTokens() + }) +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsertOne) SetInputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputCost(v) + }) +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsertOne) AddInputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputCost(v) + }) +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateInputCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputCost() + }) +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsertOne) SetOutputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputCost(v) + }) +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsertOne) AddOutputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputCost(v) + }) +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateOutputCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputCost() + }) +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. 
+func (u *UsageLogUpsertOne) SetCacheCreationCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationCost(v) + }) +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsertOne) AddCacheCreationCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationCost(v) + }) +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreationCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationCost() + }) +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsertOne) SetCacheReadCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadCost(v) + }) +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. +func (u *UsageLogUpsertOne) AddCacheReadCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadCost(v) + }) +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheReadCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadCost() + }) +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsertOne) SetTotalCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetTotalCost(v) + }) +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsertOne) AddTotalCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddTotalCost(v) + }) +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. 
+func (u *UsageLogUpsertOne) UpdateTotalCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateTotalCost() + }) +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsertOne) SetActualCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetActualCost(v) + }) +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsertOne) AddActualCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddActualCost(v) + }) +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateActualCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateActualCost() + }) +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsertOne) SetRateMultiplier(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRateMultiplier(v) + }) +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. +func (u *UsageLogUpsertOne) AddRateMultiplier(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddRateMultiplier(v) + }) +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRateMultiplier() + }) +} + +// SetBillingType sets the "billing_type" field. +func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingType(v) + }) +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsertOne) AddBillingType(v int8) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddBillingType(v) + }) +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. 
+func (u *UsageLogUpsertOne) UpdateBillingType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingType() + }) +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsertOne) SetStream(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetStream(v) + }) +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateStream() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateStream() + }) +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsertOne) SetDurationMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetDurationMs(v) + }) +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsertOne) AddDurationMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddDurationMs(v) + }) +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateDurationMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateDurationMs() + }) +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsertOne) ClearDurationMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearDurationMs() + }) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsertOne) SetFirstTokenMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetFirstTokenMs(v) + }) +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsertOne) AddFirstTokenMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddFirstTokenMs(v) + }) +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. 
+func (u *UsageLogUpsertOne) UpdateFirstTokenMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateFirstTokenMs() + }) +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsertOne) ClearFirstTokenMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearFirstTokenMs() + }) +} + +// Exec executes the query. +func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageLogCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UsageLogUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *UsageLogUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *UsageLogUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// UsageLogCreateBulk is the builder for creating many UsageLog entities in bulk. +type UsageLogCreateBulk struct { + config + err error + builders []*UsageLogCreate + conflict []sql.ConflictOption +} + +// Save creates the UsageLog entities in the database. 
+func (_c *UsageLogCreateBulk) Save(ctx context.Context) ([]*UsageLog, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*UsageLog, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UsageLogMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *UsageLogCreateBulk) SaveX(ctx context.Context) []*UsageLog { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *UsageLogCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UsageLogCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UsageLog.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UsageLogUpsert) { +// SetUserID(v+v). +// }). +// Exec(ctx) +func (_c *UsageLogCreateBulk) OnConflict(opts ...sql.ConflictOption) *UsageLogUpsertBulk { + _c.conflict = opts + return &UsageLogUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UsageLogCreateBulk) OnConflictColumns(columns ...string) *UsageLogUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UsageLogUpsertBulk{ + create: _c, + } +} + +// UsageLogUpsertBulk is the builder for "upsert"-ing +// a bulk of UsageLog nodes. +type UsageLogUpsertBulk struct { + create *UsageLogCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *UsageLogUpsertBulk) UpdateNewValues() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(usagelog.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageLogUpsertBulk) Ignore() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageLogUpsertBulk) DoNothing() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageLogCreateBulk.OnConflict +// documentation for more info. +func (u *UsageLogUpsertBulk) Update(set func(*UsageLogUpsert)) *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageLogUpsert{UpdateSet: update}) + })) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsertBulk) SetUserID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUserID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUserID() + }) +} + +// SetAPIKeyID sets the "api_key_id" field. 
+func (u *UsageLogUpsertBulk) SetAPIKeyID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetAPIKeyID(v) + }) +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateAPIKeyID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAPIKeyID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsertBulk) SetAccountID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateAccountID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAccountID() + }) +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsertBulk) SetRequestID(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestID(v) + }) +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRequestID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestID() + }) +} + +// SetModel sets the "model" field. +func (u *UsageLogUpsertBulk) SetModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModel() + }) +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetGroupID(v) + }) +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. 
+func (u *UsageLogUpsertBulk) UpdateGroupID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateGroupID() + }) +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsertBulk) ClearGroupID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearGroupID() + }) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsertBulk) SetSubscriptionID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetSubscriptionID(v) + }) +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateSubscriptionID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateSubscriptionID() + }) +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsertBulk) ClearSubscriptionID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearSubscriptionID() + }) +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsertBulk) SetInputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputTokens(v) + }) +} + +// AddInputTokens adds v to the "input_tokens" field. +func (u *UsageLogUpsertBulk) AddInputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputTokens(v) + }) +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateInputTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputTokens() + }) +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsertBulk) SetOutputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputTokens(v) + }) +} + +// AddOutputTokens adds v to the "output_tokens" field. 
+func (u *UsageLogUpsertBulk) AddOutputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputTokens(v) + }) +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateOutputTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputTokens() + }) +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheCreationTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationTokens(v) + }) +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreationTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationTokens(v) + }) +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreationTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationTokens() + }) +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheReadTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadTokens(v) + }) +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheReadTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadTokens(v) + }) +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheReadTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadTokens() + }) +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. 
+func (u *UsageLogUpsertBulk) SetCacheCreation5mTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation5mTokens(v) + }) +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreation5mTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation5mTokens(v) + }) +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreation5mTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation5mTokens() + }) +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheCreation1hTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation1hTokens(v) + }) +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreation1hTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation1hTokens(v) + }) +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreation1hTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation1hTokens() + }) +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsertBulk) SetInputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputCost(v) + }) +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsertBulk) AddInputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputCost(v) + }) +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. 
+func (u *UsageLogUpsertBulk) UpdateInputCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputCost() + }) +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsertBulk) SetOutputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputCost(v) + }) +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsertBulk) AddOutputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputCost(v) + }) +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateOutputCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputCost() + }) +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (u *UsageLogUpsertBulk) SetCacheCreationCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationCost(v) + }) +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsertBulk) AddCacheCreationCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationCost(v) + }) +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreationCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationCost() + }) +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsertBulk) SetCacheReadCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadCost(v) + }) +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. 
+func (u *UsageLogUpsertBulk) AddCacheReadCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadCost(v) + }) +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheReadCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadCost() + }) +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsertBulk) SetTotalCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetTotalCost(v) + }) +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsertBulk) AddTotalCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddTotalCost(v) + }) +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateTotalCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateTotalCost() + }) +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsertBulk) SetActualCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetActualCost(v) + }) +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsertBulk) AddActualCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddActualCost(v) + }) +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateActualCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateActualCost() + }) +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsertBulk) SetRateMultiplier(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRateMultiplier(v) + }) +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. 
+func (u *UsageLogUpsertBulk) AddRateMultiplier(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddRateMultiplier(v) + }) +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRateMultiplier() + }) +} + +// SetBillingType sets the "billing_type" field. +func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingType(v) + }) +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsertBulk) AddBillingType(v int8) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddBillingType(v) + }) +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingType() + }) +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsertBulk) SetStream(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetStream(v) + }) +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateStream() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateStream() + }) +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsertBulk) SetDurationMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetDurationMs(v) + }) +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsertBulk) AddDurationMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddDurationMs(v) + }) +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. 
+func (u *UsageLogUpsertBulk) UpdateDurationMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateDurationMs() + }) +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsertBulk) ClearDurationMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearDurationMs() + }) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsertBulk) SetFirstTokenMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetFirstTokenMs(v) + }) +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsertBulk) AddFirstTokenMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddFirstTokenMs(v) + }) +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateFirstTokenMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateFirstTokenMs() + }) +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsertBulk) ClearFirstTokenMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearFirstTokenMs() + }) +} + +// Exec executes the query. +func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UsageLogCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageLogCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *UsageLogUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/usagelog_delete.go b/backend/ent/usagelog_delete.go new file mode 100644 index 00000000..73450fda --- /dev/null +++ b/backend/ent/usagelog_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" +) + +// UsageLogDelete is the builder for deleting a UsageLog entity. +type UsageLogDelete struct { + config + hooks []Hook + mutation *UsageLogMutation +} + +// Where appends a list predicates to the UsageLogDelete builder. +func (_d *UsageLogDelete) Where(ps ...predicate.UsageLog) *UsageLogDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *UsageLogDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UsageLogDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *UsageLogDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(usagelog.Table, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// UsageLogDeleteOne is the builder for deleting a single UsageLog entity. 
+type UsageLogDeleteOne struct { + _d *UsageLogDelete +} + +// Where appends a list predicates to the UsageLogDelete builder. +func (_d *UsageLogDeleteOne) Where(ps ...predicate.UsageLog) *UsageLogDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *UsageLogDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{usagelog.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UsageLogDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/usagelog_query.go b/backend/ent/usagelog_query.go new file mode 100644 index 00000000..8e5013cc --- /dev/null +++ b/backend/ent/usagelog_query.go @@ -0,0 +1,912 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogQuery is the builder for querying UsageLog entities. +type UsageLogQuery struct { + config + ctx *QueryContext + order []usagelog.OrderOption + inters []Interceptor + predicates []predicate.UsageLog + withUser *UserQuery + withAPIKey *ApiKeyQuery + withAccount *AccountQuery + withGroup *GroupQuery + withSubscription *UserSubscriptionQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UsageLogQuery builder. 
+func (_q *UsageLogQuery) Where(ps ...predicate.UsageLog) *UsageLogQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UsageLogQuery) Limit(limit int) *UsageLogQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UsageLogQuery) Offset(offset int) *UsageLogQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UsageLogQuery) Unique(unique bool) *UsageLogQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UsageLogQuery) Order(o ...usagelog.OrderOption) *UsageLogQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *UsageLogQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAPIKey chains the current query on the "api_key" edge. 
+func (_q *UsageLogQuery) QueryAPIKey() *ApiKeyQuery { + query := (&ApiKeyClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(apikey.Table, apikey.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAccount chains the current query on the "account" edge. +func (_q *UsageLogQuery) QueryAccount() *AccountQuery { + query := (&AccountClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryGroup chains the current query on the "group" edge. 
+func (_q *UsageLogQuery) QueryGroup() *GroupQuery { + query := (&GroupClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QuerySubscription chains the current query on the "subscription" edge. +func (_q *UsageLogQuery) QuerySubscription() *UserSubscriptionQuery { + query := (&UserSubscriptionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(usersubscription.Table, usersubscription.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first UsageLog entity from the query. +// Returns a *NotFoundError when no UsageLog was found. +func (_q *UsageLogQuery) First(ctx context.Context) (*UsageLog, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{usagelog.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. 
+func (_q *UsageLogQuery) FirstX(ctx context.Context) *UsageLog { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UsageLog ID from the query. +// Returns a *NotFoundError when no UsageLog ID was found. +func (_q *UsageLogQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{usagelog.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UsageLogQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UsageLog entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UsageLog entity is found. +// Returns a *NotFoundError when no UsageLog entities are found. +func (_q *UsageLogQuery) Only(ctx context.Context) (*UsageLog, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{usagelog.Label} + default: + return nil, &NotSingularError{usagelog.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UsageLogQuery) OnlyX(ctx context.Context) *UsageLog { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UsageLog ID in the query. +// Returns a *NotSingularError when more than one UsageLog ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (_q *UsageLogQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{usagelog.Label} + default: + err = &NotSingularError{usagelog.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UsageLogQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UsageLogs. +func (_q *UsageLogQuery) All(ctx context.Context) ([]*UsageLog, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UsageLog, *UsageLogQuery]() + return withInterceptors[[]*UsageLog](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UsageLogQuery) AllX(ctx context.Context) []*UsageLog { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UsageLog IDs. +func (_q *UsageLogQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(usagelog.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UsageLogQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. 
+func (_q *UsageLogQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UsageLogQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UsageLogQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UsageLogQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UsageLogQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UsageLogQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *UsageLogQuery) Clone() *UsageLogQuery { + if _q == nil { + return nil + } + return &UsageLogQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]usagelog.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.UsageLog{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withAPIKey: _q.withAPIKey.Clone(), + withAccount: _q.withAccount.Clone(), + withGroup: _q.withGroup.Clone(), + withSubscription: _q.withSubscription.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. 
The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithUser(opts ...func(*UserQuery)) *UsageLogQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithAPIKey tells the query-builder to eager-load the nodes that are connected to +// the "api_key" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithAPIKey(opts ...func(*ApiKeyQuery)) *UsageLogQuery { + query := (&ApiKeyClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAPIKey = query + return _q +} + +// WithAccount tells the query-builder to eager-load the nodes that are connected to +// the "account" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithAccount(opts ...func(*AccountQuery)) *UsageLogQuery { + query := (&AccountClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAccount = query + return _q +} + +// WithGroup tells the query-builder to eager-load the nodes that are connected to +// the "group" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithGroup(opts ...func(*GroupQuery)) *UsageLogQuery { + query := (&GroupClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withGroup = query + return _q +} + +// WithSubscription tells the query-builder to eager-load the nodes that are connected to +// the "subscription" edge. The optional arguments are used to configure the query builder of the edge. 
+func (_q *UsageLogQuery) WithSubscription(opts ...func(*UserSubscriptionQuery)) *UsageLogQuery { + query := (&UserSubscriptionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withSubscription = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// UserID int64 `json:"user_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.UsageLog.Query(). +// GroupBy(usagelog.FieldUserID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *UsageLogQuery) GroupBy(field string, fields ...string) *UsageLogGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &UsageLogGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = usagelog.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// UserID int64 `json:"user_id,omitempty"` +// } +// +// client.UsageLog.Query(). +// Select(usagelog.FieldUserID). +// Scan(ctx, &v) +func (_q *UsageLogQuery) Select(fields ...string) *UsageLogSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &UsageLogSelect{UsageLogQuery: _q} + sbuild.label = usagelog.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UsageLogSelect configured with the given aggregations. +func (_q *UsageLogQuery) Aggregate(fns ...AggregateFunc) *UsageLogSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *UsageLogQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !usagelog.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UsageLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageLog, error) { + var ( + nodes = []*UsageLog{} + _spec = _q.querySpec() + loadedTypes = [5]bool{ + _q.withUser != nil, + _q.withAPIKey != nil, + _q.withAccount != nil, + _q.withGroup != nil, + _q.withSubscription != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UsageLog).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UsageLog{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *UsageLog, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withAPIKey; query != nil { + if err := _q.loadAPIKey(ctx, query, nodes, nil, + func(n *UsageLog, e *ApiKey) { n.Edges.APIKey = e }); err != nil { + return nil, err + } + } + if query := _q.withAccount; query != nil { + if err := _q.loadAccount(ctx, query, nodes, nil, + func(n *UsageLog, e *Account) { n.Edges.Account = e }); err != nil { + return 
nil, err + } + } + if query := _q.withGroup; query != nil { + if err := _q.loadGroup(ctx, query, nodes, nil, + func(n *UsageLog, e *Group) { n.Edges.Group = e }); err != nil { + return nil, err + } + } + if query := _q.withSubscription; query != nil { + if err := _q.loadSubscription(ctx, query, nodes, nil, + func(n *UsageLog, e *UserSubscription) { n.Edges.Subscription = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *UsageLogQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadAPIKey(ctx context.Context, query *ApiKeyQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *ApiKey)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].APIKeyID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(apikey.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "api_key_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) 
loadAccount(ctx context.Context, query *AccountQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Account)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].AccountID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(account.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "account_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Group)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + if nodes[i].GroupID == nil { + continue + } + fk := *nodes[i].GroupID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(group.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadSubscription(ctx context.Context, query *UserSubscriptionQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *UserSubscription)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + if nodes[i].SubscriptionID == nil { + continue + } + fk := *nodes[i].SubscriptionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + 
nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(usersubscription.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "subscription_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *UsageLogQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UsageLogQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagelog.FieldID) + for i := range fields { + if fields[i] != usagelog.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(usagelog.FieldUserID) + } + if _q.withAPIKey != nil { + _spec.Node.AddColumnOnce(usagelog.FieldAPIKeyID) + } + if _q.withAccount != nil { + _spec.Node.AddColumnOnce(usagelog.FieldAccountID) + } + if _q.withGroup != nil { + _spec.Node.AddColumnOnce(usagelog.FieldGroupID) + } + if _q.withSubscription != nil { + _spec.Node.AddColumnOnce(usagelog.FieldSubscriptionID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset 
:= _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(usagelog.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = usagelog.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// UsageLogGroupBy is the group-by builder for UsageLog entities. +type UsageLogGroupBy struct { + selector + build *UsageLogQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *UsageLogGroupBy) Aggregate(fns ...AggregateFunc) *UsageLogGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *UsageLogGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UsageLogQuery, *UsageLogGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *UsageLogGroupBy) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UsageLogSelect is the builder for selecting fields of UsageLog entities. +type UsageLogSelect struct { + *UsageLogQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *UsageLogSelect) Aggregate(fns ...AggregateFunc) *UsageLogSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *UsageLogSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UsageLogQuery, *UsageLogSelect](ctx, _s.UsageLogQuery, _s, _s.inters, v) +} + +func (_s *UsageLogSelect) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go new file mode 100644 index 00000000..55b8e234 --- /dev/null +++ b/backend/ent/usagelog_update.go @@ -0,0 +1,1800 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogUpdate is the builder for updating UsageLog entities. +type UsageLogUpdate struct { + config + hooks []Hook + mutation *UsageLogMutation +} + +// Where appends a list predicates to the UsageLogUpdate builder. +func (_u *UsageLogUpdate) Where(ps ...predicate.UsageLog) *UsageLogUpdate { + _u.mutation.Where(ps...) 
+ return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UsageLogUpdate) SetUserID(v int64) *UsageLogUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableUserID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetAPIKeyID sets the "api_key_id" field. +func (_u *UsageLogUpdate) SetAPIKeyID(v int64) *UsageLogUpdate { + _u.mutation.SetAPIKeyID(v) + return _u +} + +// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableAPIKeyID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetAPIKeyID(*v) + } + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *UsageLogUpdate) SetAccountID(v int64) *UsageLogUpdate { + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableAccountID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// SetRequestID sets the "request_id" field. +func (_u *UsageLogUpdate) SetRequestID(v string) *UsageLogUpdate { + _u.mutation.SetRequestID(v) + return _u +} + +// SetNillableRequestID sets the "request_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRequestID(v *string) *UsageLogUpdate { + if v != nil { + _u.SetRequestID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *UsageLogUpdate) SetModel(v string) *UsageLogUpdate { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetGroupID sets the "group_id" field. 
+func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { + _u.mutation.SetGroupID(v) + return _u +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableGroupID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetGroupID(*v) + } + return _u +} + +// ClearGroupID clears the value of the "group_id" field. +func (_u *UsageLogUpdate) ClearGroupID() *UsageLogUpdate { + _u.mutation.ClearGroupID() + return _u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_u *UsageLogUpdate) SetSubscriptionID(v int64) *UsageLogUpdate { + _u.mutation.SetSubscriptionID(v) + return _u +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableSubscriptionID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetSubscriptionID(*v) + } + return _u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (_u *UsageLogUpdate) ClearSubscriptionID() *UsageLogUpdate { + _u.mutation.ClearSubscriptionID() + return _u +} + +// SetInputTokens sets the "input_tokens" field. +func (_u *UsageLogUpdate) SetInputTokens(v int) *UsageLogUpdate { + _u.mutation.ResetInputTokens() + _u.mutation.SetInputTokens(v) + return _u +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableInputTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetInputTokens(*v) + } + return _u +} + +// AddInputTokens adds value to the "input_tokens" field. +func (_u *UsageLogUpdate) AddInputTokens(v int) *UsageLogUpdate { + _u.mutation.AddInputTokens(v) + return _u +} + +// SetOutputTokens sets the "output_tokens" field. +func (_u *UsageLogUpdate) SetOutputTokens(v int) *UsageLogUpdate { + _u.mutation.ResetOutputTokens() + _u.mutation.SetOutputTokens(v) + return _u +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. 
+func (_u *UsageLogUpdate) SetNillableOutputTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetOutputTokens(*v) + } + return _u +} + +// AddOutputTokens adds value to the "output_tokens" field. +func (_u *UsageLogUpdate) AddOutputTokens(v int) *UsageLogUpdate { + _u.mutation.AddOutputTokens(v) + return _u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_u *UsageLogUpdate) SetCacheCreationTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreationTokens() + _u.mutation.SetCacheCreationTokens(v) + return _u +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreationTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreationTokens(*v) + } + return _u +} + +// AddCacheCreationTokens adds value to the "cache_creation_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreationTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreationTokens(v) + return _u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (_u *UsageLogUpdate) SetCacheReadTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheReadTokens() + _u.mutation.SetCacheReadTokens(v) + return _u +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheReadTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheReadTokens(*v) + } + return _u +} + +// AddCacheReadTokens adds value to the "cache_read_tokens" field. +func (_u *UsageLogUpdate) AddCacheReadTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheReadTokens(v) + return _u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. 
+func (_u *UsageLogUpdate) SetCacheCreation5mTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreation5mTokens() + _u.mutation.SetCacheCreation5mTokens(v) + return _u +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreation5mTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreation5mTokens(*v) + } + return _u +} + +// AddCacheCreation5mTokens adds value to the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreation5mTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreation5mTokens(v) + return _u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdate) SetCacheCreation1hTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreation1hTokens() + _u.mutation.SetCacheCreation1hTokens(v) + return _u +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreation1hTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreation1hTokens(*v) + } + return _u +} + +// AddCacheCreation1hTokens adds value to the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreation1hTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreation1hTokens(v) + return _u +} + +// SetInputCost sets the "input_cost" field. +func (_u *UsageLogUpdate) SetInputCost(v float64) *UsageLogUpdate { + _u.mutation.ResetInputCost() + _u.mutation.SetInputCost(v) + return _u +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableInputCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetInputCost(*v) + } + return _u +} + +// AddInputCost adds value to the "input_cost" field. 
+func (_u *UsageLogUpdate) AddInputCost(v float64) *UsageLogUpdate { + _u.mutation.AddInputCost(v) + return _u +} + +// SetOutputCost sets the "output_cost" field. +func (_u *UsageLogUpdate) SetOutputCost(v float64) *UsageLogUpdate { + _u.mutation.ResetOutputCost() + _u.mutation.SetOutputCost(v) + return _u +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableOutputCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetOutputCost(*v) + } + return _u +} + +// AddOutputCost adds value to the "output_cost" field. +func (_u *UsageLogUpdate) AddOutputCost(v float64) *UsageLogUpdate { + _u.mutation.AddOutputCost(v) + return _u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_u *UsageLogUpdate) SetCacheCreationCost(v float64) *UsageLogUpdate { + _u.mutation.ResetCacheCreationCost() + _u.mutation.SetCacheCreationCost(v) + return _u +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreationCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreationCost(*v) + } + return _u +} + +// AddCacheCreationCost adds value to the "cache_creation_cost" field. +func (_u *UsageLogUpdate) AddCacheCreationCost(v float64) *UsageLogUpdate { + _u.mutation.AddCacheCreationCost(v) + return _u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_u *UsageLogUpdate) SetCacheReadCost(v float64) *UsageLogUpdate { + _u.mutation.ResetCacheReadCost() + _u.mutation.SetCacheReadCost(v) + return _u +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheReadCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetCacheReadCost(*v) + } + return _u +} + +// AddCacheReadCost adds value to the "cache_read_cost" field. 
+func (_u *UsageLogUpdate) AddCacheReadCost(v float64) *UsageLogUpdate { + _u.mutation.AddCacheReadCost(v) + return _u +} + +// SetTotalCost sets the "total_cost" field. +func (_u *UsageLogUpdate) SetTotalCost(v float64) *UsageLogUpdate { + _u.mutation.ResetTotalCost() + _u.mutation.SetTotalCost(v) + return _u +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableTotalCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetTotalCost(*v) + } + return _u +} + +// AddTotalCost adds value to the "total_cost" field. +func (_u *UsageLogUpdate) AddTotalCost(v float64) *UsageLogUpdate { + _u.mutation.AddTotalCost(v) + return _u +} + +// SetActualCost sets the "actual_cost" field. +func (_u *UsageLogUpdate) SetActualCost(v float64) *UsageLogUpdate { + _u.mutation.ResetActualCost() + _u.mutation.SetActualCost(v) + return _u +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableActualCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetActualCost(*v) + } + return _u +} + +// AddActualCost adds value to the "actual_cost" field. +func (_u *UsageLogUpdate) AddActualCost(v float64) *UsageLogUpdate { + _u.mutation.AddActualCost(v) + return _u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (_u *UsageLogUpdate) SetRateMultiplier(v float64) *UsageLogUpdate { + _u.mutation.ResetRateMultiplier() + _u.mutation.SetRateMultiplier(v) + return _u +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRateMultiplier(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetRateMultiplier(*v) + } + return _u +} + +// AddRateMultiplier adds value to the "rate_multiplier" field. 
+func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate { + _u.mutation.AddRateMultiplier(v) + return _u +} + +// SetBillingType sets the "billing_type" field. +func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate { + _u.mutation.ResetBillingType() + _u.mutation.SetBillingType(v) + return _u +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingType(v *int8) *UsageLogUpdate { + if v != nil { + _u.SetBillingType(*v) + } + return _u +} + +// AddBillingType adds value to the "billing_type" field. +func (_u *UsageLogUpdate) AddBillingType(v int8) *UsageLogUpdate { + _u.mutation.AddBillingType(v) + return _u +} + +// SetStream sets the "stream" field. +func (_u *UsageLogUpdate) SetStream(v bool) *UsageLogUpdate { + _u.mutation.SetStream(v) + return _u +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableStream(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetStream(*v) + } + return _u +} + +// SetDurationMs sets the "duration_ms" field. +func (_u *UsageLogUpdate) SetDurationMs(v int) *UsageLogUpdate { + _u.mutation.ResetDurationMs() + _u.mutation.SetDurationMs(v) + return _u +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableDurationMs(v *int) *UsageLogUpdate { + if v != nil { + _u.SetDurationMs(*v) + } + return _u +} + +// AddDurationMs adds value to the "duration_ms" field. +func (_u *UsageLogUpdate) AddDurationMs(v int) *UsageLogUpdate { + _u.mutation.AddDurationMs(v) + return _u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (_u *UsageLogUpdate) ClearDurationMs() *UsageLogUpdate { + _u.mutation.ClearDurationMs() + return _u +} + +// SetFirstTokenMs sets the "first_token_ms" field. 
+func (_u *UsageLogUpdate) SetFirstTokenMs(v int) *UsageLogUpdate { + _u.mutation.ResetFirstTokenMs() + _u.mutation.SetFirstTokenMs(v) + return _u +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableFirstTokenMs(v *int) *UsageLogUpdate { + if v != nil { + _u.SetFirstTokenMs(*v) + } + return _u +} + +// AddFirstTokenMs adds value to the "first_token_ms" field. +func (_u *UsageLogUpdate) AddFirstTokenMs(v int) *UsageLogUpdate { + _u.mutation.AddFirstTokenMs(v) + return _u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (_u *UsageLogUpdate) ClearFirstTokenMs() *UsageLogUpdate { + _u.mutation.ClearFirstTokenMs() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { + return _u.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdate) SetAPIKey(v *ApiKey) *UsageLogUpdate { + return _u.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. +func (_u *UsageLogUpdate) SetAccount(v *Account) *UsageLogUpdate { + return _u.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_u *UsageLogUpdate) SetGroup(v *Group) *UsageLogUpdate { + return _u.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdate) SetSubscription(v *UserSubscription) *UsageLogUpdate { + return _u.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_u *UsageLogUpdate) Mutation() *UsageLogMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UsageLogUpdate) ClearUser() *UsageLogUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. 
+func (_u *UsageLogUpdate) ClearAPIKey() *UsageLogUpdate { + _u.mutation.ClearAPIKey() + return _u +} + +// ClearAccount clears the "account" edge to the Account entity. +func (_u *UsageLogUpdate) ClearAccount() *UsageLogUpdate { + _u.mutation.ClearAccount() + return _u +} + +// ClearGroup clears the "group" edge to the Group entity. +func (_u *UsageLogUpdate) ClearGroup() *UsageLogUpdate { + _u.mutation.ClearGroup() + return _u +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdate) ClearSubscription() *UsageLogUpdate { + _u.mutation.ClearSubscription() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UsageLogUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UsageLogUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UsageLogUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageLogUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
func (_u *UsageLogUpdate) check() error {
	// Validate only the fields staged in the mutation; untouched columns are skipped.
	if v, ok := _u.mutation.RequestID(); ok {
		if err := usagelog.RequestIDValidator(v); err != nil {
			return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)}
		}
	}
	if v, ok := _u.mutation.Model(); ok {
		if err := usagelog.ModelValidator(v); err != nil {
			return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
		}
	}
	// user/api_key/account are required edges: rejecting Clear-without-replacement
	// (Clear + no new IDs) keeps them non-null.
	if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
		return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
	}
	if _u.mutation.APIKeyCleared() && len(_u.mutation.APIKeyIDs()) > 0 {
		return errors.New(`ent: clearing a required unique edge "UsageLog.api_key"`)
	}
	if _u.mutation.AccountCleared() && len(_u.mutation.AccountIDs()) > 0 {
		return errors.New(`ent: clearing a required unique edge "UsageLog.account"`)
	}
	return nil
}

// sqlSave translates the staged mutation into an sqlgraph update spec and
// executes it, returning the number of rows affected.
func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
	if err := _u.check(); err != nil {
		return _node, err
	}
	_spec := sqlgraph.NewUpdateSpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64))
	// Apply builder predicates (Where clauses) to scope the UPDATE.
	if ps := _u.mutation.predicates; len(ps) > 0 {
		_spec.Predicate = func(selector *sql.Selector) {
			for i := range ps {
				ps[i](selector)
			}
		}
	}
	// Scalar columns: SetField for absolute writes, AddField for relative
	// increments, ClearField for NULLing optional columns.
	if value, ok := _u.mutation.RequestID(); ok {
		_spec.SetField(usagelog.FieldRequestID, field.TypeString, value)
	}
	if value, ok := _u.mutation.Model(); ok {
		_spec.SetField(usagelog.FieldModel, field.TypeString, value)
	}
	if value, ok := _u.mutation.InputTokens(); ok {
		_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedInputTokens(); ok {
		_spec.AddField(usagelog.FieldInputTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.OutputTokens(); ok {
		_spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedOutputTokens(); ok {
		_spec.AddField(usagelog.FieldOutputTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.CacheCreationTokens(); ok {
		_spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedCacheCreationTokens(); ok {
		_spec.AddField(usagelog.FieldCacheCreationTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.CacheReadTokens(); ok {
		_spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedCacheReadTokens(); ok {
		_spec.AddField(usagelog.FieldCacheReadTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.CacheCreation5mTokens(); ok {
		_spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedCacheCreation5mTokens(); ok {
		_spec.AddField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.CacheCreation1hTokens(); ok {
		_spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedCacheCreation1hTokens(); ok {
		_spec.AddField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value)
	}
	if value, ok := _u.mutation.InputCost(); ok {
		_spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedInputCost(); ok {
		_spec.AddField(usagelog.FieldInputCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.OutputCost(); ok {
		_spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedOutputCost(); ok {
		_spec.AddField(usagelog.FieldOutputCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.CacheCreationCost(); ok {
		_spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedCacheCreationCost(); ok {
		_spec.AddField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.CacheReadCost(); ok {
		_spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedCacheReadCost(); ok {
		_spec.AddField(usagelog.FieldCacheReadCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.TotalCost(); ok {
		_spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedTotalCost(); ok {
		_spec.AddField(usagelog.FieldTotalCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.ActualCost(); ok {
		_spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedActualCost(); ok {
		_spec.AddField(usagelog.FieldActualCost, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.RateMultiplier(); ok {
		_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.AddedRateMultiplier(); ok {
		_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
	}
	if value, ok := _u.mutation.BillingType(); ok {
		_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
	}
	if value, ok := _u.mutation.AddedBillingType(); ok {
		_spec.AddField(usagelog.FieldBillingType, field.TypeInt8, value)
	}
	if value, ok := _u.mutation.Stream(); ok {
		_spec.SetField(usagelog.FieldStream, field.TypeBool, value)
	}
	if value, ok := _u.mutation.DurationMs(); ok {
		_spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedDurationMs(); ok {
		_spec.AddField(usagelog.FieldDurationMs, field.TypeInt, value)
	}
	if _u.mutation.DurationMsCleared() {
		_spec.ClearField(usagelog.FieldDurationMs, field.TypeInt)
	}
	if value, ok := _u.mutation.FirstTokenMs(); ok {
		_spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value)
	}
	if value, ok := _u.mutation.AddedFirstTokenMs(); ok {
		_spec.AddField(usagelog.FieldFirstTokenMs, field.TypeInt, value)
	}
	if _u.mutation.FirstTokenMsCleared() {
		_spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt)
	}
	// M2O edges: a Clear entry is emitted before any Add so re-assignment
	// (Clear + new ID) is applied in the right order.
	if _u.mutation.UserCleared() {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.UserTable,
			Columns: []string{usagelog.UserColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
			},
		}
		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
	}
	if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.UserTable,
			Columns: []string{usagelog.UserColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges.Add = append(_spec.Edges.Add, edge)
	}
	if _u.mutation.APIKeyCleared() {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.APIKeyTable,
			Columns: []string{usagelog.APIKeyColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64),
			},
		}
		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
	}
	if nodes := _u.mutation.APIKeyIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.APIKeyTable,
			Columns: []string{usagelog.APIKeyColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges.Add = append(_spec.Edges.Add, edge)
	}
	if _u.mutation.AccountCleared() {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.AccountTable,
			Columns: []string{usagelog.AccountColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
			},
		}
		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
	}
	if nodes := _u.mutation.AccountIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.AccountTable,
			Columns: []string{usagelog.AccountColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges.Add = append(_spec.Edges.Add, edge)
	}
	if _u.mutation.GroupCleared() {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.GroupTable,
			Columns: []string{usagelog.GroupColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
			},
		}
		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
	}
	if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.GroupTable,
			Columns: []string{usagelog.GroupColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges.Add = append(_spec.Edges.Add, edge)
	}
	if _u.mutation.SubscriptionCleared() {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.SubscriptionTable,
			Columns: []string{usagelog.SubscriptionColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64),
			},
		}
		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
	}
	if nodes := _u.mutation.SubscriptionIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   usagelog.SubscriptionTable,
			Columns: []string{usagelog.SubscriptionColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_spec.Edges.Add = append(_spec.Edges.Add, edge)
	}
	// Execute; sqlgraph errors are translated into ent's typed errors.
	if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
		if _, ok := err.(*sqlgraph.NotFoundError); ok {
			err = &NotFoundError{usagelog.Label}
		} else if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{msg: err.Error(), wrap: err}
		}
		return 0, err
	}
	_u.mutation.done = true
	return _node, nil
}

// UsageLogUpdateOne is the builder for updating a single UsageLog entity.
type UsageLogUpdateOne struct {
	config
	// fields: presumably the columns to return on the updated entity (Select) — confirm against the generated Select method.
	fields   []string
	hooks    []Hook
	mutation *UsageLogMutation
}

// SetUserID sets the "user_id" field.
func (_u *UsageLogUpdateOne) SetUserID(v int64) *UsageLogUpdateOne {
	_u.mutation.SetUserID(v)
	return _u
}

// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableUserID(v *int64) *UsageLogUpdateOne {
	if v != nil {
		_u.SetUserID(*v)
	}
	return _u
}

// SetAPIKeyID sets the "api_key_id" field.
func (_u *UsageLogUpdateOne) SetAPIKeyID(v int64) *UsageLogUpdateOne {
	_u.mutation.SetAPIKeyID(v)
	return _u
}

// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAPIKeyID(v *int64) *UsageLogUpdateOne {
	if v != nil {
		_u.SetAPIKeyID(*v)
	}
	return _u
}

// SetAccountID sets the "account_id" field.
func (_u *UsageLogUpdateOne) SetAccountID(v int64) *UsageLogUpdateOne {
	_u.mutation.SetAccountID(v)
	return _u
}

// SetNillableAccountID sets the "account_id" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableAccountID(v *int64) *UsageLogUpdateOne {
	if v != nil {
		_u.SetAccountID(*v)
	}
	return _u
}

// SetRequestID sets the "request_id" field.
+func (_u *UsageLogUpdateOne) SetRequestID(v string) *UsageLogUpdateOne { + _u.mutation.SetRequestID(v) + return _u +} + +// SetNillableRequestID sets the "request_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRequestID(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetRequestID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *UsageLogUpdateOne) SetModel(v string) *UsageLogUpdateOne { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetGroupID sets the "group_id" field. +func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { + _u.mutation.SetGroupID(v) + return _u +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableGroupID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetGroupID(*v) + } + return _u +} + +// ClearGroupID clears the value of the "group_id" field. +func (_u *UsageLogUpdateOne) ClearGroupID() *UsageLogUpdateOne { + _u.mutation.ClearGroupID() + return _u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_u *UsageLogUpdateOne) SetSubscriptionID(v int64) *UsageLogUpdateOne { + _u.mutation.SetSubscriptionID(v) + return _u +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableSubscriptionID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetSubscriptionID(*v) + } + return _u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (_u *UsageLogUpdateOne) ClearSubscriptionID() *UsageLogUpdateOne { + _u.mutation.ClearSubscriptionID() + return _u +} + +// SetInputTokens sets the "input_tokens" field. 
+func (_u *UsageLogUpdateOne) SetInputTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetInputTokens() + _u.mutation.SetInputTokens(v) + return _u +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableInputTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetInputTokens(*v) + } + return _u +} + +// AddInputTokens adds value to the "input_tokens" field. +func (_u *UsageLogUpdateOne) AddInputTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddInputTokens(v) + return _u +} + +// SetOutputTokens sets the "output_tokens" field. +func (_u *UsageLogUpdateOne) SetOutputTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetOutputTokens() + _u.mutation.SetOutputTokens(v) + return _u +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableOutputTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetOutputTokens(*v) + } + return _u +} + +// AddOutputTokens adds value to the "output_tokens" field. +func (_u *UsageLogUpdateOne) AddOutputTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddOutputTokens(v) + return _u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreationTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreationTokens() + _u.mutation.SetCacheCreationTokens(v) + return _u +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreationTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreationTokens(*v) + } + return _u +} + +// AddCacheCreationTokens adds value to the "cache_creation_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheCreationTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreationTokens(v) + return _u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. 
+func (_u *UsageLogUpdateOne) SetCacheReadTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheReadTokens() + _u.mutation.SetCacheReadTokens(v) + return _u +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheReadTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheReadTokens(*v) + } + return _u +} + +// AddCacheReadTokens adds value to the "cache_read_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheReadTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheReadTokens(v) + return _u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreation5mTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreation5mTokens() + _u.mutation.SetCacheCreation5mTokens(v) + return _u +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreation5mTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreation5mTokens(*v) + } + return _u +} + +// AddCacheCreation5mTokens adds value to the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheCreation5mTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreation5mTokens(v) + return _u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreation1hTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreation1hTokens() + _u.mutation.SetCacheCreation1hTokens(v) + return _u +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreation1hTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreation1hTokens(*v) + } + return _u +} + +// AddCacheCreation1hTokens adds value to the "cache_creation_1h_tokens" field. 
+func (_u *UsageLogUpdateOne) AddCacheCreation1hTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreation1hTokens(v) + return _u +} + +// SetInputCost sets the "input_cost" field. +func (_u *UsageLogUpdateOne) SetInputCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetInputCost() + _u.mutation.SetInputCost(v) + return _u +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableInputCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetInputCost(*v) + } + return _u +} + +// AddInputCost adds value to the "input_cost" field. +func (_u *UsageLogUpdateOne) AddInputCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddInputCost(v) + return _u +} + +// SetOutputCost sets the "output_cost" field. +func (_u *UsageLogUpdateOne) SetOutputCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetOutputCost() + _u.mutation.SetOutputCost(v) + return _u +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableOutputCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetOutputCost(*v) + } + return _u +} + +// AddOutputCost adds value to the "output_cost" field. +func (_u *UsageLogUpdateOne) AddOutputCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddOutputCost(v) + return _u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_u *UsageLogUpdateOne) SetCacheCreationCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreationCost() + _u.mutation.SetCacheCreationCost(v) + return _u +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreationCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreationCost(*v) + } + return _u +} + +// AddCacheCreationCost adds value to the "cache_creation_cost" field. 
+func (_u *UsageLogUpdateOne) AddCacheCreationCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddCacheCreationCost(v) + return _u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_u *UsageLogUpdateOne) SetCacheReadCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetCacheReadCost() + _u.mutation.SetCacheReadCost(v) + return _u +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheReadCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheReadCost(*v) + } + return _u +} + +// AddCacheReadCost adds value to the "cache_read_cost" field. +func (_u *UsageLogUpdateOne) AddCacheReadCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddCacheReadCost(v) + return _u +} + +// SetTotalCost sets the "total_cost" field. +func (_u *UsageLogUpdateOne) SetTotalCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetTotalCost() + _u.mutation.SetTotalCost(v) + return _u +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableTotalCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetTotalCost(*v) + } + return _u +} + +// AddTotalCost adds value to the "total_cost" field. +func (_u *UsageLogUpdateOne) AddTotalCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddTotalCost(v) + return _u +} + +// SetActualCost sets the "actual_cost" field. +func (_u *UsageLogUpdateOne) SetActualCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetActualCost() + _u.mutation.SetActualCost(v) + return _u +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableActualCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetActualCost(*v) + } + return _u +} + +// AddActualCost adds value to the "actual_cost" field. 
+func (_u *UsageLogUpdateOne) AddActualCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddActualCost(v) + return _u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (_u *UsageLogUpdateOne) SetRateMultiplier(v float64) *UsageLogUpdateOne { + _u.mutation.ResetRateMultiplier() + _u.mutation.SetRateMultiplier(v) + return _u +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRateMultiplier(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetRateMultiplier(*v) + } + return _u +} + +// AddRateMultiplier adds value to the "rate_multiplier" field. +func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne { + _u.mutation.AddRateMultiplier(v) + return _u +} + +// SetBillingType sets the "billing_type" field. +func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne { + _u.mutation.ResetBillingType() + _u.mutation.SetBillingType(v) + return _u +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingType(v *int8) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingType(*v) + } + return _u +} + +// AddBillingType adds value to the "billing_type" field. +func (_u *UsageLogUpdateOne) AddBillingType(v int8) *UsageLogUpdateOne { + _u.mutation.AddBillingType(v) + return _u +} + +// SetStream sets the "stream" field. +func (_u *UsageLogUpdateOne) SetStream(v bool) *UsageLogUpdateOne { + _u.mutation.SetStream(v) + return _u +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableStream(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetStream(*v) + } + return _u +} + +// SetDurationMs sets the "duration_ms" field. 
+func (_u *UsageLogUpdateOne) SetDurationMs(v int) *UsageLogUpdateOne { + _u.mutation.ResetDurationMs() + _u.mutation.SetDurationMs(v) + return _u +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableDurationMs(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetDurationMs(*v) + } + return _u +} + +// AddDurationMs adds value to the "duration_ms" field. +func (_u *UsageLogUpdateOne) AddDurationMs(v int) *UsageLogUpdateOne { + _u.mutation.AddDurationMs(v) + return _u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (_u *UsageLogUpdateOne) ClearDurationMs() *UsageLogUpdateOne { + _u.mutation.ClearDurationMs() + return _u +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (_u *UsageLogUpdateOne) SetFirstTokenMs(v int) *UsageLogUpdateOne { + _u.mutation.ResetFirstTokenMs() + _u.mutation.SetFirstTokenMs(v) + return _u +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableFirstTokenMs(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetFirstTokenMs(*v) + } + return _u +} + +// AddFirstTokenMs adds value to the "first_token_ms" field. +func (_u *UsageLogUpdateOne) AddFirstTokenMs(v int) *UsageLogUpdateOne { + _u.mutation.AddFirstTokenMs(v) + return _u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (_u *UsageLogUpdateOne) ClearFirstTokenMs() *UsageLogUpdateOne { + _u.mutation.ClearFirstTokenMs() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { + return _u.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdateOne) SetAPIKey(v *ApiKey) *UsageLogUpdateOne { + return _u.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. 
+func (_u *UsageLogUpdateOne) SetAccount(v *Account) *UsageLogUpdateOne { + return _u.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_u *UsageLogUpdateOne) SetGroup(v *Group) *UsageLogUpdateOne { + return _u.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdateOne) SetSubscription(v *UserSubscription) *UsageLogUpdateOne { + return _u.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_u *UsageLogUpdateOne) Mutation() *UsageLogMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UsageLogUpdateOne) ClearUser() *UsageLogUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdateOne) ClearAPIKey() *UsageLogUpdateOne { + _u.mutation.ClearAPIKey() + return _u +} + +// ClearAccount clears the "account" edge to the Account entity. +func (_u *UsageLogUpdateOne) ClearAccount() *UsageLogUpdateOne { + _u.mutation.ClearAccount() + return _u +} + +// ClearGroup clears the "group" edge to the Group entity. +func (_u *UsageLogUpdateOne) ClearGroup() *UsageLogUpdateOne { + _u.mutation.ClearGroup() + return _u +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdateOne) ClearSubscription() *UsageLogUpdateOne { + _u.mutation.ClearSubscription() + return _u +} + +// Where appends a list predicates to the UsageLogUpdate builder. +func (_u *UsageLogUpdateOne) Where(ps ...predicate.UsageLog) *UsageLogUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. 
+func (_u *UsageLogUpdateOne) Select(field string, fields ...string) *UsageLogUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UsageLog entity. +func (_u *UsageLogUpdateOne) Save(ctx context.Context) (*UsageLog, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UsageLogUpdateOne) SaveX(ctx context.Context) *UsageLog { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UsageLogUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageLogUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UsageLogUpdateOne) check() error { + if v, ok := _u.mutation.RequestID(); ok { + if err := usagelog.RequestIDValidator(v); err != nil { + return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)} + } + } + if v, ok := _u.mutation.Model(); ok { + if err := usagelog.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) + } + if _u.mutation.APIKeyCleared() && len(_u.mutation.APIKeyIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.api_key"`) + } + if _u.mutation.AccountCleared() && len(_u.mutation.AccountIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.account"`) + } + return nil +} + +func (_u *UsageLogUpdateOne) sqlSave(ctx 
context.Context) (_node *UsageLog, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageLog.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagelog.FieldID) + for _, f := range fields { + if !usagelog.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != usagelog.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.RequestID(); ok { + _spec.SetField(usagelog.FieldRequestID, field.TypeString, value) + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(usagelog.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.InputTokens(); ok { + _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedInputTokens(); ok { + _spec.AddField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.OutputTokens(); ok { + _spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOutputTokens(); ok { + _spec.AddField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreationTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreationTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := 
_u.mutation.CacheReadTokens(); ok { + _spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheReadTokens(); ok { + _spec.AddField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation5mTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation5mTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation1hTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation1hTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.InputCost(); ok { + _spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedInputCost(); ok { + _spec.AddField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.OutputCost(); ok { + _spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedOutputCost(); ok { + _spec.AddField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheCreationCost(); ok { + _spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheCreationCost(); ok { + _spec.AddField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheReadCost(); ok { + _spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheReadCost(); ok { + _spec.AddField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.TotalCost(); ok { + _spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := 
_u.mutation.AddedTotalCost(); ok { + _spec.AddField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ActualCost(); ok { + _spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedActualCost(); ok { + _spec.AddField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateMultiplier(); ok { + _spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateMultiplier(); ok { + _spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.BillingType(); ok { + _spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.AddedBillingType(); ok { + _spec.AddField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.Stream(); ok { + _spec.SetField(usagelog.FieldStream, field.TypeBool, value) + } + if value, ok := _u.mutation.DurationMs(); ok { + _spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDurationMs(); ok { + _spec.AddField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if _u.mutation.DurationMsCleared() { + _spec.ClearField(usagelog.FieldDurationMs, field.TypeInt) + } + if value, ok := _u.mutation.FirstTokenMs(); ok { + _spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedFirstTokenMs(); ok { + _spec.AddField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if _u.mutation.FirstTokenMsCleared() { + _spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = 
append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.APIKeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.APIKeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AccountCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k 
:= range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.GroupCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.SubscriptionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.SubscriptionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &UsageLog{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok 
:= err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{usagelog.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/user.go b/backend/ent/user.go index 1f06eb4e..eda67c84 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -59,11 +59,13 @@ type UserEdges struct { AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"` // AllowedGroups holds the value of the allowed_groups edge. AllowedGroups []*Group `json:"allowed_groups,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [6]bool + loadedTypes [7]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) { return nil, &NotLoadedError{edge: "allowed_groups"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[5] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. 
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[5] { + if e.loadedTypes[6] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery { return NewUserClient(_m.config).QueryAllowedGroups(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the User entity. +func (_m *User) QueryUsageLogs() *UsageLogQuery { + return NewUserClient(_m.config).QueryUsageLogs(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index e1e6988b..9ad87890 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -49,6 +49,8 @@ const ( EdgeAssignedSubscriptions = "assigned_subscriptions" // EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations. EdgeAllowedGroups = "allowed_groups" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -86,6 +88,13 @@ const ( // AllowedGroupsInverseTable is the table name for the Group entity. // It exists in this package in order to avoid circular dependency with the "group" package. AllowedGroupsInverseTable = "groups" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. 
+ UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2M, false, AllowedGroupsTable, AllowedGroupsPrimaryKey...), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index ad434c59..81959cf4 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. 
+func HasUsageLogs() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index 8c9caaa2..51bdc493 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate { return _c.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *UserCreate) AddUsageLogIDs(ids ...int64) *UserCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. 
func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { edge.Target.Fields = specE.Fields _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 21159a62..c172dda3 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -33,6 +34,7 @@ type UserQuery struct { withSubscriptions *UserSubscriptionQuery withAssignedSubscriptions *UserSubscriptionQuery withAllowedGroups *GroupQuery + withUsageLogs *UsageLogQuery withUserAllowedGroups *UserAllowedGroupQuery // intermediate query (i.e. traversal path). sql *sql.Selector @@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. 
+func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery { withSubscriptions: _q.withSubscriptions.Clone(), withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(), withAllowedGroups: _q.withAllowedGroups.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery { return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. 
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [6]bool{ + loadedTypes = [7]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, _q.withAssignedSubscriptions != nil, _q.withAllowedGroups != nil, + _q.withUsageLogs != nil, _q.withUserAllowedGroups != nil, } ) @@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *User) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *User, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n } return nil } +func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*User, init func(*User), assign func(*User, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldUserID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v 
for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index a00f9b8a..31e57a43 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate { return _u.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserUpdate) AddUsageLogIDs(ids ...int64) *UserUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate { return _u.RemoveAllowedGroupIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdate) ClearUsageLogs() *UserUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserUpdate) RemoveUsageLogIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) 
+ return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = 
sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne { return _u.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserUpdateOne) AddUsageLogIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne { return _u.RemoveAllowedGroupIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdateOne) ClearUsageLogs() *UserUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) 
@@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/usersubscription.go b/backend/ent/usersubscription.go index 3cfe9475..01beb2fc 100644 --- a/backend/ent/usersubscription.go +++ b/backend/ent/usersubscription.go @@ -23,6 +23,8 @@ type UserSubscription struct { CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. 
UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` // UserID holds the value of the "user_id" field. UserID int64 `json:"user_id,omitempty"` // GroupID holds the value of the "group_id" field. @@ -65,9 +67,11 @@ type UserSubscriptionEdges struct { Group *Group `json:"group,omitempty"` // AssignedByUser holds the value of the assigned_by_user edge. AssignedByUser *User `json:"assigned_by_user,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [3]bool + loadedTypes [4]bool } // UserOrErr returns the User value or an error if the edge @@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) { return nil, &NotLoadedError{edge: "assigned_by_user"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e UserSubscriptionEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[3] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // scanValues returns the types for scanning values from sql.Rows. 
func (*UserSubscription) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case usersubscription.FieldStatus, usersubscription.FieldNotes: values[i] = new(sql.NullString) - case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt: + case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldDeletedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error { } else if value.Valid { _m.UpdatedAt = value.Time } + case usersubscription.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + _m.DeletedAt = new(time.Time) + *_m.DeletedAt = value.Time + } case usersubscription.FieldUserID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field user_id", values[i]) @@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery { return NewUserSubscriptionClient(_m.config).QueryAssignedByUser(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the UserSubscription entity. 
+func (_m *UserSubscription) QueryUsageLogs() *UsageLogQuery { + return NewUserSubscriptionClient(_m.config).QueryUsageLogs(_m) +} + // Update returns a builder for updating this UserSubscription. // Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription // was returned from a transaction, and the transaction was committed or rolled back. @@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string { builder.WriteString("updated_at=") builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") + if v := _m.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("user_id=") builder.WriteString(fmt.Sprintf("%v", _m.UserID)) builder.WriteString(", ") diff --git a/backend/ent/usersubscription/usersubscription.go b/backend/ent/usersubscription/usersubscription.go index f4f7fa82..06441646 100644 --- a/backend/ent/usersubscription/usersubscription.go +++ b/backend/ent/usersubscription/usersubscription.go @@ -5,6 +5,7 @@ package usersubscription import ( "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" ) @@ -18,6 +19,8 @@ const ( FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" // FieldUserID holds the string denoting the user_id field in the database. FieldUserID = "user_id" // FieldGroupID holds the string denoting the group_id field in the database. @@ -52,6 +55,8 @@ const ( EdgeGroup = "group" // EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations. EdgeAssignedByUser = "assigned_by_user" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. 
+ EdgeUsageLogs = "usage_logs" // Table holds the table name of the usersubscription in the database. Table = "user_subscriptions" // UserTable is the table that holds the user relation/edge. @@ -75,6 +80,13 @@ const ( AssignedByUserInverseTable = "users" // AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge. AssignedByUserColumn = "assigned_by" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "subscription_id" ) // Columns holds all SQL columns for usersubscription fields. @@ -82,6 +94,7 @@ var Columns = []string{ FieldID, FieldCreatedAt, FieldUpdatedAt, + FieldDeletedAt, FieldUserID, FieldGroupID, FieldStartsAt, @@ -108,7 +121,14 @@ func ValidColumn(column string) bool { return false } +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/Wei-Shaw/sub2api/ent/runtime" var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() } +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + // ByUserID orders the results by the user_id field. 
func ByUserID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUserID, opts...).ToFunc() @@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...)) } } + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newUserStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} diff --git a/backend/ent/usersubscription/where.go b/backend/ent/usersubscription/where.go index f6060d95..250e5ed5 100644 --- a/backend/ent/usersubscription/where.go +++ b/backend/ent/usersubscription/where.go @@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v)) } +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v)) +} + // UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. 
func UserID(v int64) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) @@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v)) } +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. 
+func DeletedAtIsNil() predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNotNull(FieldDeletedAt)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v int64) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) @@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.UserSubscription { + return predicate.UserSubscription(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.UserSubscription { + return predicate.UserSubscription(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. 
func And(predicates ...predicate.UserSubscription) predicate.UserSubscription { return predicate.UserSubscription(sql.AndPredicates(predicates...)) diff --git a/backend/ent/usersubscription_create.go b/backend/ent/usersubscription_create.go index 43997f64..dd03115b 100644 --- a/backend/ent/usersubscription_create.go +++ b/backend/ent/usersubscription_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr return _c } +// SetDeletedAt sets the "deleted_at" field. +func (_c *UserSubscriptionCreate) SetDeletedAt(v time.Time) *UserSubscriptionCreate { + _c.mutation.SetDeletedAt(v) + return _c +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_c *UserSubscriptionCreate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionCreate { + if v != nil { + _c.SetDeletedAt(*v) + } + return _c +} + // SetUserID sets the "user_id" field. func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate { _c.mutation.SetUserID(v) @@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr return _c.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *UserSubscriptionCreate) AddUsageLogIDs(ids ...int64) *UserSubscriptionCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *UserSubscriptionCreate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) 
+} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { return _c.mutation @@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { // Save creates the UserSubscription in the database. func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) { - _c.defaults() + if err := _c.defaults(); err != nil { + return nil, err + } return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) } @@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_c *UserSubscriptionCreate) defaults() { +func (_c *UserSubscriptionCreate) defaults() error { if _, ok := _c.mutation.CreatedAt(); !ok { + if usersubscription.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) } if _, ok := _c.mutation.UpdatedAt(); !ok { + if usersubscription.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultUpdatedAt() _c.mutation.SetUpdatedAt(v) } @@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() { _c.mutation.SetMonthlyUsageUsd(v) } if _, ok := _c.mutation.AssignedAt(); !ok { + if usersubscription.DefaultAssignedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultAssignedAt() _c.mutation.SetAssignedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. 
@@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = value } + if value, ok := _c.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } if value, ok := _c.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) _node.StartsAt = value @@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre _node.AssignedBy = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert { return u } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsert) SetDeletedAt(v time.Time) *UserSubscriptionUpsert { + u.Set(usersubscription.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsert) UpdateDeletedAt() *UserSubscriptionUpsert { + u.SetExcluded(usersubscription.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsert) ClearDeletedAt() *UserSubscriptionUpsert { + u.SetNull(usersubscription.FieldDeletedAt) + return u +} + // SetUserID sets the "user_id" field. 
func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert { u.Set(usersubscription.FieldUserID, v) @@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne }) } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsertOne) SetDeletedAt(v time.Time) *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsertOne) UpdateDeletedAt() *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsertOne) ClearDeletedAt() *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.ClearDeletedAt() + }) +} + // SetUserID sets the "user_id" field. func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne { return u.Update(func(s *UserSubscriptionUpsert) { @@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu }) } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsertBulk) SetDeletedAt(v time.Time) *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsertBulk) UpdateDeletedAt() *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsertBulk) ClearDeletedAt() *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.ClearDeletedAt() + }) +} + // SetUserID sets the "user_id" field. 
func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk { return u.Update(func(s *UserSubscriptionUpsert) { diff --git a/backend/ent/usersubscription_query.go b/backend/ent/usersubscription_query.go index 034f29b4..967fbddb 100644 --- a/backend/ent/usersubscription_query.go +++ b/backend/ent/usersubscription_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -13,6 +14,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -27,6 +29,7 @@ type UserSubscriptionQuery struct { withUser *UserQuery withGroup *GroupQuery withAssignedByUser *UserQuery + withUsageLogs *UsageLogQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *UserSubscriptionQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first UserSubscription entity from the query. 
// Returns a *NotFoundError when no UserSubscription was found. func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) { @@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery { withUser: _q.withUser.Clone(), withGroup: _q.withGroup.Clone(), withAssignedByUser: _q.withAssignedByUser.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, @@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserSubscriptionQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserSubscriptionQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. 
// @@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) var ( nodes = []*UserSubscription{} _spec = _q.querySpec() - loadedTypes = [3]bool{ + loadedTypes = [4]bool{ _q.withUser != nil, _q.withGroup != nil, _q.withAssignedByUser != nil, + _q.withUsageLogs != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *UserSubscription) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *UserSubscription, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query * } return nil } +func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*UserSubscription) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldSubscriptionID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(usersubscription.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.SubscriptionID + if fk == nil { + return fmt.Errorf(`foreign-key "subscription_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "subscription_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return 
nil +} func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() diff --git a/backend/ent/usersubscription_update.go b/backend/ent/usersubscription_update.go index c0df17ff..811dae7e 100644 --- a/backend/ent/usersubscription_update.go +++ b/backend/ent/usersubscription_update.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd return _u } +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserSubscriptionUpdate) SetDeletedAt(v time.Time) *UserSubscriptionUpdate { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserSubscriptionUpdate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdate { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserSubscriptionUpdate) ClearDeletedAt() *UserSubscriptionUpdate { + _u.mutation.ClearDeletedAt() + return _u +} + // SetUserID sets the "user_id" field. func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate { _u.mutation.SetUserID(v) @@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp return _u.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserSubscriptionUpdate) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. 
+func (_u *UserSubscriptionUpdate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation { return _u.mutation @@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdate) ClearUsageLogs() *UserSubscriptionUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserSubscriptionUpdate) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserSubscriptionUpdate) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) { - _u.defaults() + if err := _u.defaults(); err != nil { + return 0, err + } return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. 
-func (_u *UserSubscriptionUpdate) defaults() { +func (_u *UserSubscriptionUpdate) defaults() error { if _, ok := _u.mutation.UpdatedAt(); !ok { + if usersubscription.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime) + } if value, ok := _u.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) } @@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := 
range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{usersubscription.Label} @@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription return _u } +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserSubscriptionUpdateOne) SetDeletedAt(v time.Time) *UserSubscriptionUpdateOne { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserSubscriptionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdateOne { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserSubscriptionUpdateOne) ClearDeletedAt() *UserSubscriptionUpdateOne { + _u.mutation.ClearDeletedAt() + return _u +} + // SetUserID sets the "user_id" field. func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne { _u.mutation.SetUserID(v) @@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio return _u.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. 
+func (_u *UserSubscriptionUpdateOne) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdateOne) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation { return _u.mutation @@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdateOne) ClearUsageLogs() *UserSubscriptionUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserSubscriptionUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserSubscriptionUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the UserSubscriptionUpdate builder. func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne { _u.mutation.Where(ps...) @@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use // Save executes the query and returns the updated UserSubscription entity. 
func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) { - _u.defaults() + if err := _u.defaults(); err != nil { + return nil, err + } return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_u *UserSubscriptionUpdateOne) defaults() { +func (_u *UserSubscriptionUpdateOne) defaults() error { if _, ok := _u.mutation.UpdatedAt(); !ok { + if usersubscription.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime) + } if value, ok := _u.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) } @@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := 
_u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &UserSubscription{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 19dff447..23e85e9a 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -14,6 +14,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "strconv" "time" @@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { if account == nil { - return nil + return service.ErrAccountNilInput } builder := r.client.Account.Create(). 
@@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account created, err := builder.Save(ctx) if err != nil { - return err + return translatePersistenceError(err, service.ErrAccountNotFound, nil) } account.ID = created.ID @@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } func (r *accountRepository) Delete(ctx context.Context, id int64) error { - if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { + // 使用事务保证账号与关联分组的删除原子性 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { return err } - _, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx) - return err + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client + txClient = r.client + } + + if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { + return err + } + if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil { + return err + } + + if tx != nil { + return tx.Commit() + } + return nil } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { @@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s } func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { - if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil { + // 使用事务保证删除旧绑定与创建新绑定的原子性 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return err + } + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 
已处于外部事务中(ErrTxStarted),复用当前 client + txClient = r.client + } + + if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil { return err } if len(groupIDs) == 0 { + if tx != nil { + return tx.Commit() + } return nil } builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs)) for i, groupID := range groupIDs { - builders = append(builders, r.client.AccountGroup.Create(). + builders = append(builders, txClient.AccountGroup.Create(). SetAccountID(accountID). SetGroupID(groupID). SetPriority(i+1), ) } - _, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx) - return err + if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil { + return err + } + + if tx != nil { + return tx.Commit() + } + return nil } func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) { @@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m return nil } - accountExtra, err := r.client.Account.Query(). - Where(dbaccount.IDEQ(id)). - Select(dbaccount.FieldExtra). - Only(ctx) + // 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题 + payload, err := json.Marshal(updates) if err != nil { - return translatePersistenceError(err, service.ErrAccountNotFound, nil) + return err } - extra := normalizeJSONMap(accountExtra.Extra) - for k, v := range updates { - extra[k] = v + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext( + ctx, + "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL", + payload, id, + ) + if err != nil { + return err } - _, err = r.client.Account.Update(). - Where(dbaccount.IDEQ(id)). - SetExtra(extra). 
- Save(ctx) - return err + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAccountNotFound + } + return nil } func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 269f0661..a3a52333 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -318,12 +318,13 @@ func groupEntityToService(g *dbent.Group) *service.Group { RateMultiplier: g.RateMultiplier, IsExclusive: g.IsExclusive, Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + DefaultValidityDays: g.DefaultValidityDays, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/error_translate.go b/backend/internal/repository/error_translate.go index 68348830..192f9261 100644 --- a/backend/internal/repository/error_translate.go +++ b/backend/internal/repository/error_translate.go @@ -1,6 +1,7 @@ package repository import ( + "context" "database/sql" "errors" "strings" @@ -10,6 +11,25 @@ import ( "github.com/lib/pq" ) +// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。 +// +// 这个辅助函数支持 repository 方法在事务上下文中工作: +// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client +// - 否则返回传入的默认 client +// +// 使用示例: +// +// func (r *someRepo) SomeMethod(ctx context.Context) error { +// client := clientFromContext(ctx, r.client) +// return client.SomeEntity.Create().Save(ctx) +// } +func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client { + if 
tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return defaultClient +} + // translatePersistenceError 将数据库层错误翻译为业务层错误。 // // 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。 diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5670a69b..53085247 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetSubscriptionType(groupIn.SubscriptionType). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). - SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD) + SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetDefaultValidityDays(groupIn.DefaultValidityDays) created, err := builder.Save(ctx) if err == nil { @@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetDefaultValidityDays(groupIn.DefaultValidityDays). Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) @@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er func (r *groupRepository) Delete(ctx context.Context, id int64) error { _, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx) - return err + return translatePersistenceError(err, service.ErrGroupNotFound, nil) } func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { @@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, // err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。 // Lock the group row to avoid concurrent writes while we cascade. 
- // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。 - rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id) + // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。 + rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id) if err != nil { return nil, err } @@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, var affectedUserIDs []int64 if groupSvc.IsSubscriptionType() { - rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id) + // 只查询未软删除的订阅,避免通知已取消订阅的用户 + rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id) if err != nil { return nil, err } @@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return nil, err } - if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil { + // 软删除订阅:设置 deleted_at 而非硬删除 + if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil { return nil, err } } @@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return nil, err } - // 3. Remove the group id from users.allowed_groups array (legacy representation). - // Phase 1 compatibility: also delete from user_allowed_groups join table when present. + // 3. Remove the group id from user_allowed_groups join table. + // Legacy users.allowed_groups 列已弃用,不再同步。 if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil { return nil, err } - if _, err := exec.ExecContext( - ctx, - "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)", - id, - ); err != nil { - return nil, err - } // 4. 
Delete account_groups join rows. if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index a02c5f8f..b9079d7a 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { count, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } + +// --- 软删除过滤测试 --- + +func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() { + group := &service.Group{ + Name: "to-soft-delete", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, group)) + + // 获取删除前的列表数量 + listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100}) + s.Require().NoError(err) + beforeCount := len(listBefore) + + // 软删除 + err = s.repo.Delete(s.ctx, group.ID) + s.Require().NoError(err, "Delete (soft delete)") + + // 验证列表中不再包含软删除的 group + listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100}) + s.Require().NoError(err) + s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list") + + // 验证 GetByID 也无法找到 + _, err = s.repo.GetByID(s.ctx, group.ID) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) +} + +func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() { + group := &service.Group{ + Name: "lock-soft-delete", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, group)) + + 
// 软删除 + err := s.repo.Delete(s.ctx, group.ID) + s.Require().NoError(err) + + // 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound + // 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作 + _, err = s.repo.GetByID(s.ctx, group.ID) + s.Require().Error(err, "should fail to get soft-deleted group") + s.Require().ErrorIs(err, service.ErrGroupNotFound) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 80b0fad7..49d96445 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -53,6 +53,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { var uagRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist") + + // user_subscriptions: deleted_at for soft delete support (migration 012) + requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true) + + // orphan_allowed_groups_audit table should exist (migration 013) + var orphanAuditRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass)) + require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist") + + // account_groups: created_at should be timestamptz + requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false) + + // user_allowed_groups: created_at should be timestamptz + requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { diff --git 
a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index f9315525..c24b2e2c 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) { - rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id") + rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id") if err != nil { return nil, err } diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 1429c678..ee8a01b5 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { now := time.Now() - affected, err := r.client.RedeemCode.Update(). + client := clientFromContext(ctx, r.client) + affected, err := client.RedeemCode.Update(). Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). SetStatus(service.StatusUsed). SetUsedBy(userID). 
diff --git a/backend/internal/repository/soft_delete_ent_integration_test.go b/backend/internal/repository/soft_delete_ent_integration_test.go index 02176f90..e3560ab5 100644 --- a/backend/internal/repository/soft_delete_ent_integration_test.go +++ b/backend/internal/repository/soft_delete_ent_integration_test.go @@ -7,10 +7,12 @@ import ( "fmt" "strings" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { Only(mixins.SkipSoftDelete(ctx)) require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted") } + +// --- UserSubscription 软删除测试 --- + +func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group { + t.Helper() + + g, err := client.Group.Create(). + SetName(name). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err, "create ent group") + return g +} + +func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com") + g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group")) + + repo := NewUserSubscriptionRepository(client) + sub := &service.UserSubscription{ + UserID: u.ID, + GroupID: g.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub), "create user subscription") + + require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription") + + _, err := repo.GetByID(ctx, sub.ID) + require.Error(t, err, "deleted rows should be hidden by default") + + _, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx) + require.Error(t, err, "default ent query should not see soft-deleted rows") + require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter") + + got, err := client.UserSubscription.Query(). + Where(usersubscription.IDEQ(sub.ID)). 
+ Only(mixins.SkipSoftDelete(ctx)) + require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows") + require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete") +} + +func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com") + g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2")) + + repo := NewUserSubscriptionRepository(client) + sub := &service.UserSubscription{ + UserID: u.ID, + GroupID: g.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub), "create user subscription") + + require.NoError(t, repo.Delete(ctx, sub.ID), "first delete") + require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent") +} + +func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com") + g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a")) + g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b")) + + repo := NewUserSubscriptionRepository(client) + + sub1 := &service.UserSubscription{ + UserID: u.ID, + GroupID: g1.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub1), "create subscription 1") + + sub2 := &service.UserSubscription{ + UserID: u.ID, + GroupID: g2.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub2), "create subscription 2") + + // 软删除 sub1 + require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1") + + // ListByUserID 应只返回未删除的订阅 + subs, 
err := repo.ListByUserID(ctx, u.ID) + require.NoError(t, err, "ListByUserID") + require.Len(t, subs, 1, "should only return non-deleted subscriptions") + require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned") +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 9a210bde..367ad430 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1109,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } today := timezone.Today() todayQuery := ` @@ -1135,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } return result, nil } @@ -1177,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } today := timezone.Today() todayQuery := ` @@ -1203,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } return result, nil } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 7766fe98..7294fadc 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/lib/pq" ) type userRepository struct { @@ -86,10 +85,11 @@ func (r *userRepository) GetByID(ctx context.Context, id 
int64) (*service.User, out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{id}) - if err == nil { - if v, ok := groups[id]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[id]; ok { + out.AllowedGroups = v } return out, nil } @@ -102,10 +102,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) - if err == nil { - if v, ok := groups[m.ID]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[m.ID]; ok { + out.AllowedGroups = v } return out, nil } @@ -240,11 +241,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. } allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs) - if err == nil { - for id, u := range userMap { - if groups, ok := allowedGroupsByUser[id]; ok { - u.AllowedGroups = groups - } + if err != nil { + return nil, nil, err + } + for id, u := range userMap { + if groups, ok := allowedGroupsByUser[id]; ok { + u.AllowedGroups = groups } } @@ -252,12 +254,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. } func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { - _, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) - return err + client := clientFromContext(ctx, r.client) + n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if n == 0 { + return service.ErrUserNotFound + } + return nil } func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { - n, err := r.client.User.Update(). + client := clientFromContext(ctx, r.client) + n, err := client.User.Update(). Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)). 
AddBalance(-amount). Save(ctx) @@ -271,8 +281,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo } func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { - _, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx) - return err + client := clientFromContext(ctx, r.client) + n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if n == 0 { + return service.ErrUserNotFound + } + return nil } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { @@ -280,33 +297,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, } func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { - exec := r.sql - if exec == nil { - // 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。 - exec = r.client - } - - joinAffected, err := r.client.UserAllowedGroup.Delete(). + // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 + affected, err := r.client.UserAllowedGroup.Delete(). Where(userallowedgroup.GroupIDEQ(groupID)). 
Exec(ctx) if err != nil { return 0, err } - - arrayRes, err := exec.ExecContext( - ctx, - "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)", - groupID, - ) - if err != nil { - return 0, err - } - arrayAffected, _ := arrayRes.RowsAffected() - - if int64(joinAffected) > arrayAffected { - return int64(joinAffected), nil - } - return arrayAffected, nil + return int64(affected), nil } func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) { @@ -323,10 +321,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) - if err == nil { - if v, ok := groups[m.ID]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[m.ID]; ok { + out.AllowedGroups = v } return out, nil } @@ -356,8 +355,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) } // syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组: -// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致; -// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。 +// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error { if client == nil { return nil @@ -376,12 +374,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl unique[id] = struct{}{} } - legacyGroups := make([]int64, 0, len(unique)) if len(unique) > 0 { creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique)) for groupID := range unique { creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID)) - legacyGroups = append(legacyGroups, groupID) } if err := client.UserAllowedGroup. CreateBulk(creates...). 
@@ -392,16 +388,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl } } - // Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。 - var legacy any - if len(legacyGroups) > 0 { - sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] }) - legacy = pq.Array(legacyGroups) - } - if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil { - return err - } - return nil } diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index a59d2312..c5c9e78c 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -508,3 +508,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch") } +// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 --- + +func (s *UserRepoSuite) TestUpdateBalance_NotFound() { + err := s.repo.UpdateBalance(s.ctx, 999999, 10.0) + s.Require().Error(err, "expected error for non-existent user") + s.Require().ErrorIs(err, service.ErrUserNotFound) +} + +func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() { + err := s.repo.UpdateConcurrency(s.ctx, 999999, 5) + s.Require().Error(err, "expected error for non-existent user") + s.Require().ErrorIs(err, service.ErrUserNotFound) +} + +func (s *UserRepoSuite) TestDeductBalance_NotFound() { + err := s.repo.DeductBalance(s.ctx, 999999, 5) + s.Require().Error(err, "expected error for non-existent user") + // DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配 + s.Require().ErrorIs(err, service.ErrInsufficientBalance) +} + diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 918ccab4..2b308674 100644 --- 
a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error { if sub == nil { - return nil + return service.ErrSubscriptionNilInput } - builder := r.client.UserSubscription.Create(). + client := clientFromContext(ctx, r.client) + builder := client.UserSubscription.Create(). SetUserID(sub.UserID). SetGroupID(sub.GroupID). SetExpiresAt(sub.ExpiresAt). @@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us } func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). Where(usersubscription.IDEQ(id)). WithUser(). WithGroup(). @@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se } func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). WithGroup(). Only(ctx) @@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, } func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). 
Where( usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID), @@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error { if sub == nil { - return nil + return service.ErrSubscriptionNilInput } - builder := r.client.UserSubscription.UpdateOneID(sub.ID). + client := clientFromContext(ctx, r.client) + builder := client.UserSubscription.UpdateOneID(sub.ID). SetUserID(sub.UserID). SetGroupID(sub.GroupID). SetStartsAt(sub.StartsAt). @@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { // Match GORM semantics: deleting a missing row is not an error. - _, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx) + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx) return err } func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID)). WithGroup(). Order(dbent.Desc(usersubscription.FieldCreatedAt)). @@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in } func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). 
Where( usersubscription.UserIDEQ(userID), usersubscription.StatusEQ(service.SubscriptionStatusActive), @@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use } func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { - q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)) + client := clientFromContext(ctx, r.client) + q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)) total, err := q.Clone().Count(ctx) if err != nil { @@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID } func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { - q := r.client.UserSubscription.Query() + client := clientFromContext(ctx, r.client) + q := client.UserSubscription.Query() if userID != nil { q = q.Where(usersubscription.UserIDEQ(*userID)) } @@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination } func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { - return r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + return client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). Exist(ctx) } func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetExpiresAt(newExpiresAt). 
Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetStatus(status). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetNotes(notes). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetDailyWindowStart(start). SetWeeklyWindowStart(start). SetMonthlyWindowStart(start). @@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int } func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetDailyUsageUsd(0). SetDailyWindowStart(newWindowStart). Save(ctx) @@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int } func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). 
+ client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetWeeklyUsageUsd(0). SetWeeklyWindowStart(newWindowStart). Save(ctx) @@ -266,24 +283,112 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in } func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetMonthlyUsageUsd(0). SetMonthlyWindowStart(newWindowStart). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } +// IncrementUsage 原子性地累加用量并校验限额。 +// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 +// 当更新失败时,会执行额外查询确定具体超出的限额类型。 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - _, err := r.client.UserSubscription.UpdateOneID(id). - AddDailyUsageUsd(costUSD). - AddWeeklyUsageUsd(costUSD). - AddMonthlyUsageUsd(costUSD). 
- Save(ctx) - return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) + // 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 + // NULL 限额表示无限制 + const atomicUpdateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g + WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at IS NULL + AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd) + AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd) + AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd) + ` + + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + + if affected > 0 { + return nil // 更新成功 + } + + // affected == 0:可能是订阅不存在、分组已删除、或限额超出 + // 执行额外查询确定具体原因 + return r.checkIncrementFailureReason(ctx, id, costUSD) +} + +// checkIncrementFailureReason 查询更新失败的具体原因 +func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error { + const checkSQL = ` + SELECT + CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted' + WHEN g.id IS NULL THEN 'subscription_not_found' + WHEN g.deleted_at IS NOT NULL THEN 'group_deleted' + WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded' + WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded' + WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded' + ELSE 'unknown' + END AS reason + FROM user_subscriptions us + LEFT JOIN groups g ON us.group_id = g.id + WHERE us.id = $2 + ` + + client := 
clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, checkSQL, costUSD, id) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return service.ErrSubscriptionNotFound + } + + var reason string + if err := rows.Scan(&reason); err != nil { + return err + } + + if err := rows.Err(); err != nil { + return err + } + + switch reason { + case "subscription_not_found", "subscription_deleted", "group_deleted": + return service.ErrSubscriptionNotFound + case "daily_exceeded": + return service.ErrDailyLimitExceeded + case "weekly_exceeded": + return service.ErrWeeklyLimitExceeded + case "monthly_exceeded": + return service.ErrMonthlyLimitExceeded + default: + // unknown 情况理论上不应发生,但作为兜底返回 + return service.ErrSubscriptionNotFound + } } func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { - n, err := r.client.UserSubscription.Update(). + client := clientFromContext(ctx, r.client) + n, err := client.UserSubscription.Update(). Where( usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.ExpiresAtLTE(time.Now()), @@ -296,7 +401,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex // Extra repository helpers (currently used only by integration tests). func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). 
Where( usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.ExpiresAtLTE(time.Now()), @@ -309,12 +415,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service } func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { - count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx) + client := clientFromContext(ctx, r.client) + count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx) return int64(count), err } func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { - count, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + count, err := client.UserSubscription.Query(). Where( usersubscription.GroupIDEQ(groupID), usersubscription.StatusEQ(service.SubscriptionStatusActive), @@ -325,7 +433,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g } func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { - n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx) + client := clientFromContext(ctx, r.client) + n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx) return int64(n), err } diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 282b9673..3a6c6434 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "context" + "fmt" "testing" "time" @@ -631,3 +632,249 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().NoError(err, "GetByID 
expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } + +// --- 限额检查与软删除过滤测试 --- + +func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group { + s.T().Helper() + + create := s.client.Group.Create(). + SetName(name). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeSubscription) + + if daily != nil { + create.SetDailyLimitUsd(*daily) + } + if weekly != nil { + create.SetWeeklyLimitUsd(*weekly) + } + if monthly != nil { + create.SetMonthlyLimitUsd(*monthly) + } + + g, err := create.Save(s.ctx) + s.Require().NoError(err, "create group with limits") + return groupEntityToService(g) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() { + user := s.mustCreateUser("dailylimit@test.com", service.RoleUser) + dailyLimit := 10.0 + group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 先增加 9.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0) + s.Require().NoError(err, "first increment should succeed") + + // 再增加 2.0,会超过 10.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0) + s.Require().Error(err, "should fail when daily limit exceeded") + s.Require().ErrorIs(err, service.ErrDailyLimitExceeded) + + // 验证用量没有变化 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment") +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() { + user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser) + weeklyLimit := 50.0 + group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 增加 45.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0) + s.Require().NoError(err, "first 
increment should succeed") + + // 再增加 10.0,会超过 50.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) + s.Require().Error(err, "should fail when weekly limit exceeded") + s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() { + user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser) + monthlyLimit := 100.0 + group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 增加 90.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0) + s.Require().NoError(err, "first increment should succeed") + + // 再增加 20.0,会超过 100.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0) + s.Require().Error(err, "should fail when monthly limit exceeded") + s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() { + user := s.mustCreateUser("nolimits@test.com", service.RoleUser) + group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额 + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 应该可以增加任意金额 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0) + s.Require().NoError(err, "should succeed without limits") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() { + user := s.mustCreateUser("exactlimit@test.com", service.RoleUser) + dailyLimit := 10.0 + group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 正好达到限额应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) + s.Require().NoError(err, "should succeed at exact limit") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(10.0, got.DailyUsageUSD, 
1e-6) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { + user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) + group := s.mustCreateGroup("g-softdeleted") + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 软删除分组 + _, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx) + s.Require().NoError(err, "soft delete group") + + // IncrementUsage 应该失败,因为分组已软删除 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0) + s.Require().Error(err, "should fail for soft-deleted group") + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() { + err := s.repo.IncrementUsage(s.ctx, 999999, 1.0) + s.Require().Error(err, "should fail for non-existent subscription") + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} + +// --- nil 入参测试 --- + +func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() { + err := s.repo.Create(s.ctx, nil) + s.Require().Error(err, "Create should fail with nil input") + s.Require().ErrorIs(err, service.ErrSubscriptionNilInput) +} + +func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { + err := s.repo.Update(s.ctx, nil) + s.Require().Error(err, "Update should fail with nil input") + s.Require().ErrorIs(err, service.ErrSubscriptionNilInput) +} + +// --- 并发用量更新测试 --- + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { + user := s.mustCreateUser("concurrent@test.com", service.RoleUser) + group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + const numGoroutines = 10 + const incrementPerGoroutine = 1.5 + + // 启动多个 goroutine 并发调用 IncrementUsage + errCh := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine) + }() + } + + // 等待所有 goroutine 完成 + for i := 0; i < numGoroutines; i++ { + err := 
<-errCh + s.Require().NoError(err, "IncrementUsage should succeed") + } + + // 验证累加结果正确 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + expectedUsage := float64(numGoroutines) * incrementPerGoroutine + s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated") + s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated") + s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() { + user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser) + dailyLimit := 5.0 + group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑 + // 尝试增加 10 次,每次 1.0,但限额只有 5.0 + const numAttempts = 10 + const incrementPerAttempt = 1.0 + + successCount := 0 + for i := 0; i < numAttempts; i++ { + err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt) + if err == nil { + successCount++ + } + } + + // 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额) + s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)") + + // 验证最终用量等于限额 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit") +} + +func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { + baseClient := testEntClient(s.T()) + tx, err := baseClient.Tx(context.Background()) + s.Require().NoError(err, "begin tx") + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + txCtx := dbent.NewTxContext(context.Background(), tx) + suffix := fmt.Sprintf("%d", time.Now().UnixNano()) + + userEnt, err := tx.Client().User.Create(). + SetEmail("tx-user-" + suffix + "@example.com"). + SetPasswordHash("test"). 
+ Save(txCtx) + s.Require().NoError(err, "create user in tx") + + groupEnt, err := tx.Client().Group.Create(). + SetName("tx-group-" + suffix). + Save(txCtx) + s.Require().NoError(err, "create group in tx") + + repo := NewUserSubscriptionRepository(baseClient) + sub := &service.UserSubscription{ + UserID: userEnt.ID, + GroupID: groupEnt.ID, + ExpiresAt: time.Now().AddDate(0, 0, 30), + Status: service.SubscriptionStatusActive, + AssignedAt: time.Now(), + Notes: "tx", + } + s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx") + s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx") + + s.Require().NoError(tx.Rollback(), "rollback tx") + tx = nil + + _, err = repo.GetByID(context.Background(), sub.ID) + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index ca3c4250..05895c8b 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -11,6 +11,7 @@ import ( var ( ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found") + ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil") ) type AccountRepository interface { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index f1e36b89..7d6f407d 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -11,10 +11,11 @@ type Group struct { IsExclusive bool Status string - SubscriptionType string - DailyLimitUSD *float64 - WeeklyLimitUSD *float64 - MonthlyLimitUSD *float64 + SubscriptionType string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DefaultValidityDays int CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index c587d212..7b0b80f5 100644 --- 
a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -72,6 +73,7 @@ type RedeemService struct { subscriptionService *SubscriptionService cache RedeemCache billingCacheService *BillingCacheService + entClient *dbent.Client } // NewRedeemService 创建兑换码服务实例 @@ -81,6 +83,7 @@ func NewRedeemService( subscriptionService *SubscriptionService, cache RedeemCache, billingCacheService *BillingCacheService, + entClient *dbent.Client, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, @@ -88,6 +91,7 @@ func NewRedeemService( subscriptionService: subscriptionService, cache: cache, billingCacheService: billingCacheService, + entClient: entClient, } } @@ -248,9 +252,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( } _ = user // 使用变量避免未使用错误 + // 使用数据库事务保证兑换码标记与权益发放的原子性 + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // 将事务放入 context,使 repository 方法能够使用同一事务 + txCtx := dbent.NewTxContext(ctx, tx) + // 【关键】先标记兑换码为已使用,确保并发安全 // 利用数据库乐观锁(WHERE status = 'unused')保证原子性 - if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil { + if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil { if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) { return nil, ErrRedeemCodeUsed } @@ -261,21 +275,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( switch redeemCode.Type { case RedeemTypeBalance: // 增加用户余额 - if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil { + if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { return nil, fmt.Errorf("update user 
balance: %w", err) } - // 失效余额缓存 - if s.billingCacheService != nil { - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) - }() - } case RedeemTypeConcurrency: // 增加用户并发数 - if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil { + if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { return nil, fmt.Errorf("update user concurrency: %w", err) } @@ -284,7 +290,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( if validityDays <= 0 { validityDays = 30 } - _, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ UserID: userID, GroupID: *redeemCode.GroupID, ValidityDays: validityDays, @@ -294,20 +300,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( if err != nil { return nil, fmt.Errorf("assign or extend subscription: %w", err) } - // 失效订阅缓存 - if s.billingCacheService != nil { - groupID := *redeemCode.GroupID - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) - }() - } default: return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type) } + // 提交事务 + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 事务提交成功后失效缓存 + s.invalidateRedeemCaches(ctx, userID, redeemCode) + // 重新获取更新后的兑换码 redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) if err != nil { @@ -317,6 +322,31 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( return redeemCode, nil } +// invalidateRedeemCaches 失效兑换相关的缓存 +func (s *RedeemService) invalidateRedeemCaches(ctx 
context.Context, userID int64, redeemCode *RedeemCode) { + if s.billingCacheService == nil { + return + } + + switch redeemCode.Type { + case RedeemTypeBalance: + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + }() + case RedeemTypeSubscription: + if redeemCode.GroupID != nil { + groupID := *redeemCode.GroupID + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + }() + } + } +} + // GetByID 根据ID获取兑换码 func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index fec6c147..09554c0f 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -26,6 +26,7 @@ var ( ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") + ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") ) // SubscriptionService 订阅服务 diff --git a/backend/migrations/011_remove_duplicate_unique_indexes.sql b/backend/migrations/011_remove_duplicate_unique_indexes.sql new file mode 100644 index 00000000..8fd62710 --- /dev/null +++ b/backend/migrations/011_remove_duplicate_unique_indexes.sql @@ -0,0 +1,39 @@ +-- 011_remove_duplicate_unique_indexes.sql +-- 移除重复的唯一索引 +-- 这些字段在 ent schema 的 Fields() 中已声明 .Unique(), +-- 因此在 Indexes() 中再次声明 index.Fields("x").Unique() 会创建重复索引。 +-- 
本迁移脚本清理这些冗余索引。 + +-- 重复索引命名约定(由 Ent 自动生成/历史迁移遗留): +-- - 字段级 Unique() 创建的索引名: <table>_<column>_key +-- - Indexes() 中的 Unique() 创建的索引名:<entity_label>_<column>
+-- - 初始化迁移中的非唯一索引: idx_<table>_<column>
+ +-- 仅当索引存在时才删除(幂等操作) + +-- api_keys 表: key 字段 +DROP INDEX IF EXISTS apikey_key; +DROP INDEX IF EXISTS api_keys_key; +DROP INDEX IF EXISTS idx_api_keys_key; + +-- users 表: email 字段 +DROP INDEX IF EXISTS user_email; +DROP INDEX IF EXISTS users_email; +DROP INDEX IF EXISTS idx_users_email; + +-- settings 表: key 字段 +DROP INDEX IF EXISTS settings_key; +DROP INDEX IF EXISTS idx_settings_key; + +-- redeem_codes 表: code 字段 +DROP INDEX IF EXISTS redeemcode_code; +DROP INDEX IF EXISTS redeem_codes_code; +DROP INDEX IF EXISTS idx_redeem_codes_code; + +-- groups 表: name 字段 +DROP INDEX IF EXISTS group_name; +DROP INDEX IF EXISTS groups_name; +DROP INDEX IF EXISTS idx_groups_name; + +-- 注意: 每个字段的唯一约束仍由字段级 Unique() 创建的约束保留, +-- 如 api_keys_key_key、users_email_key 等。 diff --git a/backend/migrations/012_add_user_subscription_soft_delete.sql b/backend/migrations/012_add_user_subscription_soft_delete.sql new file mode 100644 index 00000000..b6cb7366 --- /dev/null +++ b/backend/migrations/012_add_user_subscription_soft_delete.sql @@ -0,0 +1,13 @@ +-- 012: 为 user_subscriptions 表添加软删除支持 +-- 任务:fix-medium-data-hygiene 1.1 + +-- 添加 deleted_at 字段 +ALTER TABLE user_subscriptions +ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ DEFAULT NULL; + +-- 添加 deleted_at 索引以优化软删除查询 +CREATE INDEX IF NOT EXISTS usersubscription_deleted_at +ON user_subscriptions (deleted_at); + +-- 注释:与其他使用软删除的实体保持一致 +COMMENT ON COLUMN user_subscriptions.deleted_at IS '软删除时间戳,NULL 表示未删除'; diff --git a/backend/migrations/013_log_orphan_allowed_groups.sql b/backend/migrations/013_log_orphan_allowed_groups.sql new file mode 100644 index 00000000..976c0aca --- /dev/null +++ b/backend/migrations/013_log_orphan_allowed_groups.sql @@ -0,0 +1,32 @@ +-- 013: 记录 users.allowed_groups 中的孤立 group_id +-- 任务:fix-medium-data-hygiene 3.1 +-- +-- 目的:在删除 legacy allowed_groups 列前,记录所有引用了不存在 group 的孤立记录 +-- 这些记录可用于审计或后续数据修复 + +-- 创建审计表存储孤立的 allowed_groups 记录 +CREATE TABLE IF NOT EXISTS orphan_allowed_groups_audit ( + id BIGSERIAL PRIMARY 
KEY, + user_id BIGINT NOT NULL, + group_id BIGINT NOT NULL, + recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, group_id) +); + +-- 记录孤立的 group_id(存在于 users.allowed_groups 但不存在于 groups 表) +INSERT INTO orphan_allowed_groups_audit (user_id, group_id) +SELECT u.id, x.group_id +FROM users u +CROSS JOIN LATERAL unnest(u.allowed_groups) AS x(group_id) +LEFT JOIN groups g ON g.id = x.group_id +WHERE u.allowed_groups IS NOT NULL + AND g.id IS NULL +ON CONFLICT (user_id, group_id) DO NOTHING; + +-- 添加索引便于查询 +CREATE INDEX IF NOT EXISTS idx_orphan_allowed_groups_audit_user_id +ON orphan_allowed_groups_audit(user_id); + +-- 记录迁移完成信息 +COMMENT ON TABLE orphan_allowed_groups_audit IS +'审计表:记录 users.allowed_groups 中引用的不存在的 group_id,用于数据清理前的审计'; diff --git a/backend/migrations/014_drop_legacy_allowed_groups.sql b/backend/migrations/014_drop_legacy_allowed_groups.sql new file mode 100644 index 00000000..2c2a3d45 --- /dev/null +++ b/backend/migrations/014_drop_legacy_allowed_groups.sql @@ -0,0 +1,15 @@ +-- 014: 删除 legacy users.allowed_groups 列 +-- 任务:fix-medium-data-hygiene 3.3 +-- +-- 前置条件: +-- - 迁移 007 已将数据回填到 user_allowed_groups 联接表 +-- - 迁移 013 已记录所有孤立的 group_id 到审计表 +-- - 应用代码已停止写入该列(3.2 完成) +-- +-- 该列现已废弃,所有读写操作均使用 user_allowed_groups 联接表。 + +-- 删除 allowed_groups 列 +ALTER TABLE users DROP COLUMN IF EXISTS allowed_groups; + +-- 添加注释记录删除原因 +COMMENT ON TABLE users IS '用户表。注:原 allowed_groups BIGINT[] 列已迁移至 user_allowed_groups 联接表';