Files
sub2api/backend/internal/repository/user_repo.go
2026-04-22 16:38:36 +08:00

980 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
// userRepository is the ent-backed user store handed to the service layer
// (NewUserRepository returns it as a service.UserRepository). It also carries
// a raw SQL executor for the hand-written queries in this file.
type userRepository struct {
// client is the ent client used for ORM-level reads and writes.
client *dbent.Client
// sql runs the raw SQL statements (usage_logs and user_attribute_values
// lookups) and is handed to txAwareSQLExecutor when taking repository locks.
sql sqlExecutor
}
// NewUserRepository builds the user repository exposed to the service layer,
// backed by the given ent client and database handle.
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
	repo := newUserRepositoryWithSQL(client, sqlDB)
	return repo
}
// newUserRepositoryWithSQL wires a userRepository from an ent client and any
// sqlExecutor implementation.
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
	repo := new(userRepository)
	repo.client = client
	repo.sql = sqlq
	return repo
}
// Create persists a new user together with its allowed-group rows and its
// email auth identity. A nil userIn is a no-op.
//
// The work runs inside an ent transaction: one is started and committed here
// unless r.client.Tx reports ErrTxStarted, in which case the transaction found
// in ctx (or r.client as a fallback) is reused and the caller owns
// commit/rollback. Before inserting, a lock keyed on the normalized email is
// taken and the normalized email is checked for availability, which yields
// service.ErrEmailExists when another user already matches. On success the ID,
// signup source and timestamps of the stored row are copied back into userIn.
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
	if userIn == nil {
		return nil
	}

	// Use ent transactions uniformly: this keeps the user row and the
	// allowed-group update atomic, and avoids the ExecQuerier assertion error
	// caused by hand-building an ent client on top of a *sql.Tx.
	tx, beginErr := r.client.Tx(ctx)
	if beginErr != nil && !errors.Is(beginErr, dbent.ErrTxStarted) {
		return beginErr
	}

	opCtx := ctx
	var opClient *dbent.Client
	switch {
	case beginErr == nil:
		// We own the transaction; Rollback after a successful Commit is ignored.
		defer func() { _ = tx.Rollback() }()
		opClient = tx.Client()
		opCtx = dbent.NewTxContext(ctx, tx)
	default:
		// Already inside an outer transaction (ErrTxStarted): reuse the current
		// transaction's client; the caller is responsible for commit/rollback.
		opClient = r.client
		if outerTx := dbent.TxFromContext(ctx); outerTx != nil {
			opClient = outerTx.Client()
		}
	}

	sqlExec := txAwareSQLExecutor(opCtx, r.sql, r.client)
	lockKey := normalizedEmailUniquenessLockKey(userIn.Email)
	unlock, err := lockRepositoryScopedKeys(opCtx, opClient, sqlExec, lockKey)
	if err != nil {
		return err
	}
	defer unlock()

	// userID 0 means "no existing row to exclude" from the uniqueness check.
	if err := ensureNormalizedEmailAvailableWithClient(opCtx, opClient, 0, userIn.Email); err != nil {
		return err
	}

	insert := opClient.User.Create().
		SetEmail(userIn.Email).
		SetUsername(userIn.Username).
		SetNotes(userIn.Notes).
		SetPasswordHash(userIn.PasswordHash).
		SetRole(userIn.Role).
		SetBalance(userIn.Balance).
		SetConcurrency(userIn.Concurrency).
		SetStatus(userIn.Status).
		SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
		SetNillableLastLoginAt(userIn.LastLoginAt).
		SetNillableLastActiveAt(userIn.LastActiveAt)
	created, err := insert.Save(opCtx)
	if err != nil {
		return translatePersistenceError(err, nil, service.ErrEmailExists)
	}

	if err := r.syncUserAllowedGroupsWithClient(opCtx, opClient, created.ID, userIn.AllowedGroups); err != nil {
		return err
	}
	if err := ensureEmailAuthIdentityWithClient(opCtx, opClient, created.ID, created.Email, "user_repo_create"); err != nil {
		return err
	}

	// tx is only non-nil when this call started the transaction itself.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	applyUserEntityToService(userIn, created)
	return nil
}
// GetByID loads the user with the given ID and attaches its allowed group IDs.
// Query errors are passed through translatePersistenceError with
// service.ErrUserNotFound as the not-found mapping.
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
	byID := r.client.User.Query().Where(dbuser.IDEQ(id))
	entity, err := byID.Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	user := userEntityToService(entity)

	groupsByUser, err := r.loadAllowedGroups(ctx, []int64{id})
	if err != nil {
		return nil, err
	}
	if allowed, found := groupsByUser[id]; found {
		user.AllowedGroups = allowed
	}
	return user, nil
}
// GetByEmail looks a user up through userEmailLookupPredicate and attaches its
// allowed group IDs. It returns service.ErrUserNotFound when nothing matches
// and a plain error when the lookup matches more than one user.
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
	lookup := r.client.User.Query().
		Where(userEmailLookupPredicate(email)).
		Order(dbent.Asc(dbuser.FieldID))
	candidates, err := lookup.All(ctx)
	if err != nil {
		return nil, err
	}

	switch len(candidates) {
	case 0:
		return nil, service.ErrUserNotFound
	case 1:
		// Exactly one match: continue below.
	default:
		return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
	}

	entity := candidates[0]
	user := userEntityToService(entity)

	groupsByUser, err := r.loadAllowedGroups(ctx, []int64{entity.ID})
	if err != nil {
		return nil, err
	}
	if allowed, found := groupsByUser[entity.ID]; found {
		user.AllowedGroups = allowed
	}
	return user, nil
}
// Update writes the mutable fields of userIn to the stored user, replaces its
// allowed-group rows and moves its email auth identity from the old address to
// the new one. A nil userIn is a no-op.
//
// Transaction handling mirrors Create: a transaction is started and committed
// here unless r.client.Tx reports ErrTxStarted, in which case the transaction
// in ctx (or r.client as a fallback) is reused and the caller owns
// commit/rollback. SignupSource, LastLoginAt and LastActiveAt are only written
// when set on userIn; a nil BalanceNotifyThreshold clears the stored value.
// On success userIn.UpdatedAt is refreshed from the stored row.
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
	if userIn == nil {
		return nil
	}

	// Wrap the user update and the allowed_groups sync in one ent transaction
	// to avoid cross-layer transactional inconsistency.
	tx, beginErr := r.client.Tx(ctx)
	if beginErr != nil && !errors.Is(beginErr, dbent.ErrTxStarted) {
		return beginErr
	}

	opCtx := ctx
	var opClient *dbent.Client
	switch {
	case beginErr == nil:
		// We own the transaction; Rollback after a successful Commit is ignored.
		defer func() { _ = tx.Rollback() }()
		opClient = tx.Client()
		opCtx = dbent.NewTxContext(ctx, tx)
	default:
		// Already inside an outer transaction (ErrTxStarted): reuse the current
		// transaction's client; the caller is responsible for commit/rollback.
		opClient = r.client
		if outerTx := dbent.TxFromContext(ctx); outerTx != nil {
			opClient = outerTx.Client()
		}
	}

	sqlExec := txAwareSQLExecutor(opCtx, r.sql, r.client)
	lockKey := normalizedEmailUniquenessLockKey(userIn.Email)
	unlock, err := lockRepositoryScopedKeys(opCtx, opClient, sqlExec, lockKey)
	if err != nil {
		return err
	}
	defer unlock()

	// The user's own row is excluded from the uniqueness check via userIn.ID.
	if err := ensureNormalizedEmailAvailableWithClient(opCtx, opClient, userIn.ID, userIn.Email); err != nil {
		return err
	}

	current, err := clientFromContext(opCtx, opClient).User.Get(opCtx, userIn.ID)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	// Remember the previous address so its auth identity can be removed later.
	previousEmail := current.Email

	mutation := opClient.User.UpdateOneID(userIn.ID).
		SetEmail(userIn.Email).
		SetUsername(userIn.Username).
		SetNotes(userIn.Notes).
		SetPasswordHash(userIn.PasswordHash).
		SetRole(userIn.Role).
		SetBalance(userIn.Balance).
		SetConcurrency(userIn.Concurrency).
		SetStatus(userIn.Status).
		SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
		SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
		SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
		SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
		SetTotalRecharged(userIn.TotalRecharged)
	if source := userIn.SignupSource; source != "" {
		mutation = mutation.SetSignupSource(source)
	}
	if loginAt := userIn.LastLoginAt; loginAt != nil {
		mutation = mutation.SetLastLoginAt(*loginAt)
	}
	if activeAt := userIn.LastActiveAt; activeAt != nil {
		mutation = mutation.SetLastActiveAt(*activeAt)
	}
	if userIn.BalanceNotifyThreshold == nil {
		mutation = mutation.ClearBalanceNotifyThreshold()
	}

	saved, err := mutation.Save(opCtx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
	}

	if err := r.syncUserAllowedGroupsWithClient(opCtx, opClient, saved.ID, userIn.AllowedGroups); err != nil {
		return err
	}
	if err := replaceEmailAuthIdentityWithClient(opCtx, opClient, saved.ID, previousEmail, saved.Email, "user_repo_update"); err != nil {
		return err
	}

	// tx is only non-nil when this call started the transaction itself.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	userIn.UpdatedAt = saved.UpdatedAt
	return nil
}
// ensureEmailAuthIdentityWithClient makes sure an "email" auth identity exists
// for the normalized form of email and that it belongs to userID.
//
// It is a no-op when no client is available, when userID is not positive, or
// when the email normalizes to an empty subject. The insert uses
// ON CONFLICT DO NOTHING on (provider type, provider key, provider subject);
// afterwards the row is read back and ErrAuthIdentityOwnershipConflict is
// returned if it is owned by a different user.
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
	client = clientFromContext(ctx, client)
	if client == nil {
		return nil
	}
	if userID <= 0 {
		return nil
	}

	subject := normalizeEmailAuthIdentitySubject(email)
	if subject == "" {
		return nil
	}

	insertErr := client.AuthIdentity.Create().
		SetUserID(userID).
		SetProviderType("email").
		SetProviderKey("email").
		SetProviderSubject(subject).
		SetVerifiedAt(time.Now().UTC()).
		SetMetadata(map[string]any{"source": source}).
		OnConflictColumns(
			authidentity.FieldProviderType,
			authidentity.FieldProviderKey,
			authidentity.FieldProviderSubject,
		).
		DoNothing().
		Exec(ctx)
	// A "no rows" error is tolerated here; presumably that is how a skipped
	// conflicting insert surfaces — confirm against isSQLNoRowsError.
	if insertErr != nil && !isSQLNoRowsError(insertErr) {
		return insertErr
	}

	owner, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(subject),
		).
		Only(ctx)
	switch {
	case err == nil:
		// Found the identity row; verify ownership below.
	case dbent.IsNotFound(err):
		return nil
	default:
		return err
	}

	if owner.UserID != userID {
		return ErrAuthIdentityOwnershipConflict
	}
	return nil
}
// replaceEmailAuthIdentityWithClient ensures the email auth identity for
// newEmail exists for userID and then deletes the user's identity for oldEmail,
// unless the old subject is empty or identical to the new one.
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
	if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
		return err
	}

	staleSubject := normalizeEmailAuthIdentitySubject(oldEmail)
	currentSubject := normalizeEmailAuthIdentitySubject(newEmail)
	if staleSubject == "" || staleSubject == currentSubject {
		return nil
	}

	// Only the identity owned by this user for the old address is removed.
	dropStale := clientFromContext(ctx, client).AuthIdentity.Delete().
		Where(
			authidentity.UserIDEQ(userID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(staleSubject),
		)
	_, err := dropStale.Exec(ctx)
	return err
}
// normalizeEmailAuthIdentitySubject returns the trimmed, lower-cased email to
// use as an auth identity subject. It returns "" for blank input and for
// addresses ending in one of the LinuxDo, OIDC or WeChat synthetic email
// domains declared in the service package.
func normalizeEmailAuthIdentitySubject(email string) string {
	subject := strings.TrimSpace(email)
	subject = strings.ToLower(subject)
	if subject == "" {
		return ""
	}

	syntheticDomains := []string{
		service.LinuxDoConnectSyntheticEmailDomain,
		service.OIDCConnectSyntheticEmailDomain,
		service.WeChatConnectSyntheticEmailDomain,
	}
	for _, domain := range syntheticDomains {
		if strings.HasSuffix(subject, domain) {
			return ""
		}
	}
	return subject
}
// Delete removes the user with the given ID together with its auth identities.
//
// For the user's identities it first clears identity references on adoption
// decisions, then deletes the identity channels and the identities themselves,
// and finally deletes the user row. service.ErrUserNotFound is returned when
// no user row was deleted. A transaction is started and committed here unless
// r.client.Tx reports ErrTxStarted, in which case the transaction in ctx (or
// r.client as a fallback) is reused and not committed by this method.
func (r *userRepository) Delete(ctx context.Context, id int64) error {
	// Every error in this method goes through the same translation.
	translate := func(err error) error {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	tx, beginErr := r.client.Tx(ctx)
	if beginErr != nil && !errors.Is(beginErr, dbent.ErrTxStarted) {
		return translate(beginErr)
	}

	var db *dbent.Client
	if beginErr != nil {
		// ErrTxStarted: prefer the transaction carried in ctx.
		db = r.client
		if outerTx := dbent.TxFromContext(ctx); outerTx != nil {
			db = outerTx.Client()
		}
	} else {
		// We own the transaction; Rollback after a successful Commit is ignored.
		defer func() { _ = tx.Rollback() }()
		db = tx.Client()
	}

	identityIDs, err := db.AuthIdentity.Query().
		Where(authidentity.UserIDEQ(id)).
		IDs(ctx)
	if err != nil {
		return translate(err)
	}

	if len(identityIDs) > 0 {
		detachDecisions := db.IdentityAdoptionDecision.Update().
			Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
			ClearIdentityID()
		if _, err := detachDecisions.Save(ctx); err != nil {
			return translate(err)
		}

		dropChannels := db.AuthIdentityChannel.Delete().
			Where(authidentitychannel.IdentityIDIn(identityIDs...))
		if _, err := dropChannels.Exec(ctx); err != nil {
			return translate(err)
		}

		dropIdentities := db.AuthIdentity.Delete().
			Where(authidentity.UserIDEQ(id))
		if _, err := dropIdentities.Exec(ctx); err != nil {
			return translate(err)
		}
	}

	removed, err := db.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
	switch {
	case err != nil:
		return translate(err)
	case removed == 0:
		return service.ErrUserNotFound
	}

	// tx is only non-nil when this call started the transaction itself.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return translate(err)
		}
	}
	return nil
}
// List returns one page of users without any filtering; it delegates to
// ListWithFilters with a zero-value filter set.
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
	var noFilters service.UserListFilters
	return r.ListWithFilters(ctx, params, noFilters)
}
// ListWithFilters returns one page of users matching filters, plus pagination
// metadata derived from the total match count.
//
// Supported filters: exact status and role, a case-insensitive substring
// search over email, username, notes and API keys, a case-insensitive
// substring match on allowed-group names, and attribute filters resolved via
// filterUsersByAttributes. Each returned user carries its allowed group IDs
// and, unless filters.IncludeSubscriptions points to false, its active
// subscriptions with their groups loaded.
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
	query := r.client.User.Query()

	if status := filters.Status; status != "" {
		query = query.Where(dbuser.StatusEQ(status))
	}
	if role := filters.Role; role != "" {
		query = query.Where(dbuser.RoleEQ(role))
	}
	if term := filters.Search; term != "" {
		matchesTerm := dbuser.Or(
			dbuser.EmailContainsFold(term),
			dbuser.UsernameContainsFold(term),
			dbuser.NotesContainsFold(term),
			dbuser.HasAPIKeysWith(apikey.KeyContainsFold(term)),
		)
		query = query.Where(matchesTerm)
	}
	if groupName := filters.GroupName; groupName != "" {
		inGroup := dbuser.HasAllowedGroupsWith(dbgroup.NameContainsFold(groupName))
		query = query.Where(inGroup)
	}

	// Attribute filters are resolved to a set of user IDs first and then
	// applied as an ID restriction.
	if len(filters.Attributes) > 0 {
		matchingIDs, err := r.filterUsersByAttributes(ctx, filters.Attributes)
		if err != nil {
			return nil, nil, err
		}
		if len(matchingIDs) == 0 {
			// No user satisfies the attribute filters.
			return []service.User{}, paginationResultFromTotal(0, params), nil
		}
		query = query.Where(dbuser.IDIn(matchingIDs...))
	}

	// Count on a clone so the paging settings below do not affect the total.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	page := query.Offset(params.Offset()).Limit(params.Limit())
	for _, order := range userListOrder(params) {
		page = page.Order(order)
	}
	rows, err := page.All(ctx)
	if err != nil {
		return nil, nil, err
	}

	result := make([]service.User, 0, len(rows))
	if len(rows) == 0 {
		return result, paginationResultFromTotal(int64(total), params), nil
	}

	// position maps a user ID to its index in result so related rows can be
	// attached without a nested scan.
	ids := make([]int64, 0, len(rows))
	position := make(map[int64]int, len(rows))
	for _, row := range rows {
		ids = append(ids, row.ID)
		converted := userEntityToService(row)
		position[converted.ID] = len(result)
		result = append(result, *converted)
	}

	wantSubscriptions := true
	if flag := filters.IncludeSubscriptions; flag != nil {
		wantSubscriptions = *flag
	}
	if wantSubscriptions {
		// Batch load active subscriptions with groups to avoid N+1 queries.
		subs, err := r.client.UserSubscription.Query().
			Where(
				usersubscription.UserIDIn(ids...),
				usersubscription.StatusEQ(service.SubscriptionStatusActive),
			).
			WithGroup().
			All(ctx)
		if err != nil {
			return nil, nil, err
		}
		for _, sub := range subs {
			idx, known := position[sub.UserID]
			if !known {
				continue
			}
			result[idx].Subscriptions = append(result[idx].Subscriptions, *userSubscriptionEntityToService(sub))
		}
	}

	groupsByUser, err := r.loadAllowedGroups(ctx, ids)
	if err != nil {
		return nil, nil, err
	}
	for id, idx := range position {
		if groups, found := groupsByUser[id]; found {
			result[idx].AllowedGroups = groups
		}
	}

	return result, paginationResultFromTotal(int64(total), params), nil
}
// userListOrder translates the pagination sort settings into ent order
// functions for the user list.
//
// "last_used_at" is delegated to userLastUsedAtOrder. The recognized columns
// (email, username, role, balance, concurrency, status, created_at,
// last_active_at) sort by that column with the user ID as tie-breaker in the
// same direction; last_active_at additionally places NULLs last. Any other
// value sorts by ID only. The direction defaults to descending.
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
	sortKey := strings.ToLower(strings.TrimSpace(params.SortBy))
	direction := params.NormalizedSortOrder(pagination.SortOrderDesc)
	if sortKey == "last_used_at" {
		return userLastUsedAtOrder(direction)
	}

	ascending := direction == pagination.SortOrderAsc

	// The ID ordering is both the fallback sort and the tie-breaker.
	var byID func(*entsql.Selector)
	if ascending {
		byID = dbent.Asc(dbuser.FieldID)
	} else {
		byID = dbent.Desc(dbuser.FieldID)
	}

	var (
		column    string
		nullsLast bool
	)
	switch sortKey {
	case "email":
		column = dbuser.FieldEmail
	case "username":
		column = dbuser.FieldUsername
	case "role":
		column = dbuser.FieldRole
	case "balance":
		column = dbuser.FieldBalance
	case "concurrency":
		column = dbuser.FieldConcurrency
	case "status":
		column = dbuser.FieldStatus
	case "created_at":
		column = dbuser.FieldCreatedAt
	case "last_active_at":
		column, nullsLast = dbuser.FieldLastActiveAt, true
	default:
		// Unknown or empty sort key: order by ID alone.
		return []func(*entsql.Selector){byID}
	}

	var primary func(*entsql.Selector)
	switch {
	case nullsLast && ascending:
		primary = entsql.OrderByField(column, entsql.OrderNullsLast()).ToFunc()
	case nullsLast:
		primary = entsql.OrderByField(column, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc()
	case ascending:
		primary = dbent.Asc(column)
	default:
		primary = dbent.Desc(column)
	}
	return []func(*entsql.Selector){primary, byID}
}
// GetLatestUsedAtByUserIDs returns, per user ID, the most recent
// usage_logs.created_at converted to UTC. Users without usage rows are absent
// from the map. An empty userIDs slice yields an empty map without querying;
// a missing SQL executor is reported as an error.
func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
	latest := make(map[int64]*time.Time, len(userIDs))
	if len(userIDs) == 0 {
		return latest, nil
	}
	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	const query = `
SELECT user_id, MAX(created_at) AS last_used_at
FROM usage_logs
WHERE user_id = ANY($1)
GROUP BY user_id
`
	rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
	if err != nil {
		return nil, err
	}
	// The Close error is deliberately ignored; iteration errors surface via rows.Err.
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var id int64
		var seenAt time.Time
		if scanErr := rows.Scan(&id, &seenAt); scanErr != nil {
			return nil, scanErr
		}
		utc := seenAt.UTC()
		latest[id] = &utc
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return latest, nil
}
// GetLatestUsedAtByUserID is the single-user form of GetLatestUsedAtByUserIDs.
// It returns nil (and no error) when the user has no usage rows.
func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
	byUser, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
	if err != nil {
		return nil, err
	}
	if seenAt, found := byUser[userID]; found {
		return seenAt, nil
	}
	return nil, nil
}
// userLastUsedAtOrder builds the ordering for the "last_used_at" sort: a
// correlated subquery over usage_logs selects each user's latest created_at,
// with the user ID as tie-breaker. Ascending order puts NULLs first,
// descending (the default) puts them last.
func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
	direction, nulls, tieBreak := "DESC", "LAST", entsql.Desc
	if sortOrder == pagination.SortOrderAsc {
		direction, nulls, tieBreak = "ASC", "FIRST", entsql.Asc
	}

	order := func(s *entsql.Selector) {
		idColumn := s.C(dbuser.FieldID)
		lastUsed := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", idColumn)
		s.OrderExpr(entsql.Expr(lastUsed + " " + direction + " NULLS " + nulls))
		s.OrderBy(tieBreak(idColumn))
	}
	return []func(*entsql.Selector){order}
}
// filterUsersByAttributes returns the IDs of users that match ALL of the given
// attribute filters.
//
// attrs maps an attribute ID to a search value; a user matches a filter when
// its stored value for that attribute contains the search value
// (case-insensitive). The search value is treated as literal text: the LIKE
// metacharacters %, _ and \ are escaped so they cannot act as wildcards.
//
// It returns (nil, nil) for an empty filter set and an error when no SQL
// executor is configured.
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
	if len(attrs) == 0 {
		return nil, nil
	}
	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	// Backslash is PostgreSQL's default LIKE escape character, so no ESCAPE
	// clause is needed. It is escaped first in the pair list; Replacer makes a
	// single pass, so the inserted backslashes are not escaped again.
	likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

	clauses := make([]string, 0, len(attrs))
	args := make([]any, 0, len(attrs)*2+1)
	argIndex := 1
	for attrID, value := range attrs {
		clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
		args = append(args, attrID, "%"+likeEscaper.Replace(value)+"%")
		argIndex += 2
	}

	// Rows matching ANY filter are grouped per user; requiring the number of
	// distinct matched attributes to equal len(attrs) turns that into ALL.
	query := fmt.Sprintf(
		`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
		strings.Join(clauses, " OR "),
		argIndex,
	)
	args = append(args, len(attrs))

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	// The Close error is deliberately ignored; iteration errors surface via rows.Err.
	defer func() { _ = rows.Close() }()

	result := make([]int64, 0)
	for rows.Next() {
		var userID int64
		if scanErr := rows.Scan(&userID); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, userID)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// UpdateBalance adds amount (which may be negative) to the user's balance.
// Positive amounts are also added to total_recharged. It returns
// service.ErrUserNotFound when no row was updated.
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	mutation := clientFromContext(ctx, r.client).User.
		Update().
		Where(dbuser.IDEQ(id)).
		AddBalance(amount)
	if amount > 0 {
		// Track the cumulative recharge amount for percentage-based notifications.
		mutation = mutation.AddTotalRecharged(amount)
	}

	affected, err := mutation.Save(ctx)
	switch {
	case err != nil:
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	case affected == 0:
		return service.ErrUserNotFound
	}
	return nil
}
// DeductBalance subtracts amount from the user's balance.
// Overdraft policy (per the original author's note): the balance is allowed to
// go negative so the in-flight request can complete; middleware is expected to
// block further requests from users whose balance is <= 0.
// It returns service.ErrUserNotFound when no row was updated.
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
	mutation := clientFromContext(ctx, r.client).User.
		Update().
		Where(dbuser.IDEQ(id)).
		AddBalance(-amount)

	affected, err := mutation.Save(ctx)
	switch {
	case err != nil:
		return err
	case affected == 0:
		return service.ErrUserNotFound
	}
	return nil
}
// UpdateConcurrency adds amount (which may be negative) to the user's
// concurrency value. It returns service.ErrUserNotFound when no row was updated.
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
	mutation := clientFromContext(ctx, r.client).User.
		Update().
		Where(dbuser.IDEQ(id)).
		AddConcurrency(amount)

	affected, err := mutation.Save(ctx)
	switch {
	case err != nil:
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	case affected == 0:
		return service.ErrUserNotFound
	}
	return nil
}
// ExistsByEmail reports whether any user matches email under the lookup rules
// of userEmailLookupPredicate.
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	byEmail := userEmailLookupPredicate(email)
	lookup := r.client.User.Query().Where(byEmail)
	return lookup.Exist(ctx)
}
// ensureNormalizedEmailAvailableWithClient returns service.ErrEmailExists when
// a user other than userID already matches email under
// userEmailLookupPredicate. Pass userID 0 when there is no row to exclude.
// It is a no-op when no client is available.
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
	client = clientFromContext(ctx, client)
	if client == nil {
		return nil
	}

	lookup := client.User.Query().Where(userEmailLookupPredicate(email))
	holders, err := lookup.All(ctx)
	if err != nil {
		return err
	}

	for i := range holders {
		if holders[i].ID == userID {
			// The user's own row does not count as a conflict.
			continue
		}
		return service.ErrEmailExists
	}
	return nil
}
// userEmailLookupPredicate builds the predicate used to find a user by email.
// For non-blank input it compares LOWER(TRIM(email column)) with the
// normalized input; for blank input it falls back to an exact match on the
// raw value.
func userEmailLookupPredicate(email string) predicate.User {
	lookup := normalizeEmailLookupValue(email)
	if lookup != "" {
		return predicate.User(func(s *entsql.Selector) {
			emailColumn := s.C(dbuser.FieldEmail)
			s.Where(entsql.P(func(b *entsql.Builder) {
				b.WriteString("LOWER(TRIM(")
				b.Ident(emailColumn)
				b.WriteString(")) = ")
				b.Arg(lookup)
			}))
		})
	}
	return dbuser.EmailEQ(email)
}
// normalizeEmailLookupValue returns email with surrounding whitespace removed
// and all letters lower-cased.
func normalizeEmailLookupValue(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
// normalizedEmailUniquenessLockKey returns the lock key that guards uniqueness
// of the normalized email, or "" when the email normalizes to an empty string.
func normalizedEmailUniquenessLockKey(email string) string {
	if normalized := normalizeEmailLookupValue(email); normalized != "" {
		return "users:normalized-email:" + normalized
	}
	return ""
}
// AddGroupToAllowedGroups grants userID access to groupID by inserting a
// user_allowed_groups row. The insert uses ON CONFLICT DO NOTHING on
// (user_id, group_id), and a "no rows" error is treated as success.
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	insert := clientFromContext(ctx, r.client).UserAllowedGroup.Create().
		SetUserID(userID).
		SetGroupID(groupID).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing()

	insertErr := insert.Exec(ctx)
	if isSQLNoRowsError(insertErr) {
		return nil
	}
	return insertErr
}
// RemoveGroupFromAllowedGroups deletes every user_allowed_groups row that
// references groupID and returns the number of rows removed.
//
// The ent client is resolved through clientFromContext, matching the other
// allowed-group writers in this repository (AddGroupToAllowedGroups,
// RemoveGroupFromUserAllowedGroups), rather than always using r.client.
// NOTE(review): clientFromContext presumably returns the transaction-bound
// client when ctx carries one and r.client otherwise — confirm against its
// definition.
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
	// Only the user_allowed_groups join table is touched; the legacy
	// users.allowed_groups column is deprecated.
	client := clientFromContext(ctx, r.client)
	affected, err := client.UserAllowedGroup.Delete().
		Where(userallowedgroup.GroupIDEQ(groupID)).
		Exec(ctx)
	if err != nil {
		return 0, err
	}
	return int64(affected), nil
}
// RemoveGroupFromUserAllowedGroups revokes a single user's access to the given
// group by deleting the matching user_allowed_groups row. The number of
// deleted rows is not reported.
func (r *userRepository) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	revoke := clientFromContext(ctx, r.client).UserAllowedGroup.Delete().
		Where(
			userallowedgroup.UserIDEQ(userID),
			userallowedgroup.GroupIDEQ(groupID),
		)
	_, err := revoke.Exec(ctx)
	return err
}
// GetFirstAdmin returns the active admin user with the lowest ID, including
// its allowed group IDs. Query errors are passed through
// translatePersistenceError with service.ErrUserNotFound as the not-found
// mapping.
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
	activeAdmins := r.client.User.Query().
		Where(
			dbuser.RoleEQ(service.RoleAdmin),
			dbuser.StatusEQ(service.StatusActive),
		).
		Order(dbent.Asc(dbuser.FieldID))
	entity, err := activeAdmins.First(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	admin := userEntityToService(entity)

	groupsByUser, err := r.loadAllowedGroups(ctx, []int64{entity.ID})
	if err != nil {
		return nil, err
	}
	if allowed, found := groupsByUser[entity.ID]; found {
		admin.AllowedGroups = allowed
	}
	return admin, nil
}
// loadAllowedGroups returns, for each given user ID that has any, its allowed
// group IDs in ascending order. Users without rows are absent from the map;
// an empty userIDs slice yields an empty map without querying.
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
	groupsByUser := make(map[int64][]int64, len(userIDs))
	if len(userIDs) == 0 {
		return groupsByUser, nil
	}

	links, err := r.client.UserAllowedGroup.Query().
		Where(userallowedgroup.UserIDIn(userIDs...)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	for _, link := range links {
		groupsByUser[link.UserID] = append(groupsByUser[link.UserID], link.GroupID)
	}
	// Sorting the ranged slice value reorders the shared backing array, so the
	// map entries end up sorted as well.
	for _, groupIDs := range groupsByUser {
		sort.Slice(groupIDs, func(a, b int) bool { return groupIDs[a] < groupIDs[b] })
	}
	return groupsByUser, nil
}
// syncUserAllowedGroupsWithClient replaces the user's allowed groups using the
// given ent client (typically a transaction client). Only the
// user_allowed_groups join table is touched; the legacy users.allowed_groups
// column is deprecated.
//
// All existing rows for the user are deleted, then one row per distinct
// positive group ID is inserted with ON CONFLICT DO NOTHING. A nil client is a
// no-op, and a "no rows" error from the bulk insert is treated as success.
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
	if client == nil {
		return nil
	}

	// Keep the join table as the source of truth for reads.
	clearExisting := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID))
	if _, err := clearExisting.Exec(ctx); err != nil {
		return err
	}

	// Drop non-positive IDs and duplicates.
	wanted := make(map[int64]struct{}, len(groupIDs))
	for _, groupID := range groupIDs {
		if groupID > 0 {
			wanted[groupID] = struct{}{}
		}
	}
	if len(wanted) == 0 {
		return nil
	}

	builders := make([]*dbent.UserAllowedGroupCreate, 0, len(wanted))
	for groupID := range wanted {
		builders = append(builders, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
	}

	insertErr := client.UserAllowedGroup.
		CreateBulk(builders...).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing().
		Exec(ctx)
	if insertErr != nil && !isSQLNoRowsError(insertErr) {
		return insertErr
	}
	return nil
}
// applyUserEntityToService copies the ID, signup source and the login, active,
// created and updated timestamps from the stored entity into dst. It does
// nothing when either argument is nil.
func applyUserEntityToService(dst *service.User, src *dbent.User) {
	if dst != nil && src != nil {
		dst.ID = src.ID
		dst.SignupSource = src.SignupSource
		dst.LastLoginAt, dst.LastActiveAt = src.LastLoginAt, src.LastActiveAt
		dst.CreatedAt, dst.UpdatedAt = src.CreatedAt, src.UpdatedAt
	}
}
// userSignupSourceOrDefault normalizes a signup source: "linuxdo", "wechat"
// and "oidc" (case-insensitive, surrounding whitespace ignored) are returned
// in lower case; everything else, including blank input, becomes "email".
func userSignupSourceOrDefault(signupSource string) string {
	normalized := strings.TrimSpace(strings.ToLower(signupSource))
	switch normalized {
	case "linuxdo", "wechat", "oidc":
		return normalized
	}
	return "email"
}
// marshalExtraEmails converts the notification email entries to the string
// stored on the user row by delegating to service.MarshalNotifyEmails.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
	encoded := service.MarshalNotifyEmails(entries)
	return encoded
}
// UpdateTotpSecret stores the user's encrypted TOTP secret, or clears it when
// encryptedSecret is nil.
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	mutation := clientFromContext(ctx, r.client).User.UpdateOneID(userID)
	if encryptedSecret != nil {
		mutation = mutation.SetTotpSecretEncrypted(*encryptedSecret)
	} else {
		mutation = mutation.ClearTotpSecretEncrypted()
	}

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}
// EnableTotp turns on TOTP two-factor authentication for the user and records
// the current time as the moment it was enabled.
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
	mutation := clientFromContext(ctx, r.client).User.
		UpdateOneID(userID).
		SetTotpEnabled(true).
		SetTotpEnabledAt(time.Now())

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}
// DisableTotp turns off TOTP two-factor authentication for the user and clears
// both the enabled-at timestamp and the stored encrypted secret.
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
	mutation := clientFromContext(ctx, r.client).User.
		UpdateOneID(userID).
		SetTotpEnabled(false).
		ClearTotpEnabledAt().
		ClearTotpSecretEncrypted()

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}