853 lines
24 KiB
Go
853 lines
24 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/lib/pq"
|
||
|
||
entsql "entgo.io/ent/dialect/sql"
|
||
)
|
||
|
||
type userRepository struct {
|
||
client *dbent.Client
|
||
sql sqlExecutor
|
||
}
|
||
|
||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||
return newUserRepositoryWithSQL(client, sqlDB)
|
||
}
|
||
|
||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||
return &userRepository{client: client, sql: sqlq}
|
||
}
|
||
|
||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||
if userIn == nil {
|
||
return nil
|
||
}
|
||
|
||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||
tx, err := r.client.Tx(ctx)
|
||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||
return err
|
||
}
|
||
|
||
var txClient *dbent.Client
|
||
if err == nil {
|
||
defer func() { _ = tx.Rollback() }()
|
||
txClient = tx.Client()
|
||
} else {
|
||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||
txClient = r.client
|
||
}
|
||
|
||
created, err := txClient.User.Create().
|
||
SetEmail(userIn.Email).
|
||
SetUsername(userIn.Username).
|
||
SetNotes(userIn.Notes).
|
||
SetPasswordHash(userIn.PasswordHash).
|
||
SetRole(userIn.Role).
|
||
SetBalance(userIn.Balance).
|
||
SetConcurrency(userIn.Concurrency).
|
||
SetStatus(userIn.Status).
|
||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||
Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||
}
|
||
|
||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||
return err
|
||
}
|
||
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||
return err
|
||
}
|
||
|
||
if tx != nil {
|
||
if err := tx.Commit(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
applyUserEntityToService(userIn, created)
|
||
return nil
|
||
}
|
||
|
||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
|
||
out := userEntityToService(m)
|
||
groups, err := r.loadAllowedGroups(ctx, []int64{id})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if v, ok := groups[id]; ok {
|
||
out.AllowedGroups = v
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||
matches, err := r.client.User.Query().
|
||
Where(userEmailLookupPredicate(email)).
|
||
Order(dbent.Asc(dbuser.FieldID)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(matches) == 0 {
|
||
return nil, service.ErrUserNotFound
|
||
}
|
||
if len(matches) > 1 {
|
||
return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
|
||
}
|
||
m := matches[0]
|
||
|
||
out := userEntityToService(m)
|
||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if v, ok := groups[m.ID]; ok {
|
||
out.AllowedGroups = v
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
|
||
if userIn == nil {
|
||
return nil
|
||
}
|
||
|
||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||
tx, err := r.client.Tx(ctx)
|
||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||
return err
|
||
}
|
||
|
||
var txClient *dbent.Client
|
||
if err == nil {
|
||
defer func() { _ = tx.Rollback() }()
|
||
txClient = tx.Client()
|
||
} else {
|
||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||
txClient = r.client
|
||
}
|
||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
oldEmail := existing.Email
|
||
|
||
updateOp := txClient.User.UpdateOneID(userIn.ID).
|
||
SetEmail(userIn.Email).
|
||
SetUsername(userIn.Username).
|
||
SetNotes(userIn.Notes).
|
||
SetPasswordHash(userIn.PasswordHash).
|
||
SetRole(userIn.Role).
|
||
SetBalance(userIn.Balance).
|
||
SetConcurrency(userIn.Concurrency).
|
||
SetStatus(userIn.Status).
|
||
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
|
||
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||
SetTotalRecharged(userIn.TotalRecharged)
|
||
if userIn.SignupSource != "" {
|
||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||
}
|
||
if userIn.LastLoginAt != nil {
|
||
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
|
||
}
|
||
if userIn.LastActiveAt != nil {
|
||
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
|
||
}
|
||
if userIn.BalanceNotifyThreshold == nil {
|
||
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||
}
|
||
updated, err := updateOp.Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||
}
|
||
|
||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||
return err
|
||
}
|
||
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||
return err
|
||
}
|
||
|
||
if tx != nil {
|
||
if err := tx.Commit(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
userIn.UpdatedAt = updated.UpdatedAt
|
||
return nil
|
||
}
|
||
|
||
func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
|
||
return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
|
||
}
|
||
|
||
func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
|
||
return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
|
||
}
|
||
|
||
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
|
||
client = clientFromContext(ctx, client)
|
||
if client == nil || userID <= 0 {
|
||
return nil
|
||
}
|
||
|
||
subject := normalizeEmailAuthIdentitySubject(email)
|
||
if subject == "" {
|
||
return nil
|
||
}
|
||
|
||
if err := client.AuthIdentity.Create().
|
||
SetUserID(userID).
|
||
SetProviderType("email").
|
||
SetProviderKey("email").
|
||
SetProviderSubject(subject).
|
||
SetVerifiedAt(time.Now().UTC()).
|
||
SetMetadata(map[string]any{"source": source}).
|
||
OnConflictColumns(
|
||
authidentity.FieldProviderType,
|
||
authidentity.FieldProviderKey,
|
||
authidentity.FieldProviderSubject,
|
||
).
|
||
DoNothing().
|
||
Exec(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
identity, err := client.AuthIdentity.Query().
|
||
Where(
|
||
authidentity.ProviderTypeEQ("email"),
|
||
authidentity.ProviderKeyEQ("email"),
|
||
authidentity.ProviderSubjectEQ(subject),
|
||
).
|
||
Only(ctx)
|
||
if err != nil {
|
||
if dbent.IsNotFound(err) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
if identity.UserID != userID {
|
||
return ErrAuthIdentityOwnershipConflict
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
|
||
newSubject := normalizeEmailAuthIdentitySubject(newEmail)
|
||
if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
|
||
return err
|
||
}
|
||
|
||
oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
|
||
if oldSubject == "" || oldSubject == newSubject {
|
||
return nil
|
||
}
|
||
|
||
_, err := clientFromContext(ctx, client).AuthIdentity.Delete().
|
||
Where(
|
||
authidentity.UserIDEQ(userID),
|
||
authidentity.ProviderTypeEQ("email"),
|
||
authidentity.ProviderKeyEQ("email"),
|
||
authidentity.ProviderSubjectEQ(oldSubject),
|
||
).
|
||
Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func normalizeEmailAuthIdentitySubject(email string) string {
|
||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||
if normalized == "" {
|
||
return ""
|
||
}
|
||
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
|
||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
|
||
return ""
|
||
}
|
||
return normalized
|
||
}
|
||
|
||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
if affected == 0 {
|
||
return service.ErrUserNotFound
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||
return r.ListWithFilters(ctx, params, service.UserListFilters{})
|
||
}
|
||
|
||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||
q := r.client.User.Query()
|
||
|
||
if filters.Status != "" {
|
||
q = q.Where(dbuser.StatusEQ(filters.Status))
|
||
}
|
||
if filters.Role != "" {
|
||
q = q.Where(dbuser.RoleEQ(filters.Role))
|
||
}
|
||
if filters.Search != "" {
|
||
q = q.Where(
|
||
dbuser.Or(
|
||
dbuser.EmailContainsFold(filters.Search),
|
||
dbuser.UsernameContainsFold(filters.Search),
|
||
dbuser.NotesContainsFold(filters.Search),
|
||
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
|
||
),
|
||
)
|
||
}
|
||
|
||
if filters.GroupName != "" {
|
||
q = q.Where(dbuser.HasAllowedGroupsWith(
|
||
dbgroup.NameContainsFold(filters.GroupName),
|
||
))
|
||
}
|
||
|
||
// If attribute filters are specified, we need to filter by user IDs first
|
||
var allowedUserIDs []int64
|
||
if len(filters.Attributes) > 0 {
|
||
var attrErr error
|
||
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||
if attrErr != nil {
|
||
return nil, nil, attrErr
|
||
}
|
||
if len(allowedUserIDs) == 0 {
|
||
// No users match the attribute filters
|
||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||
}
|
||
q = q.Where(dbuser.IDIn(allowedUserIDs...))
|
||
}
|
||
|
||
total, err := q.Clone().Count(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
usersQuery := q.
|
||
Offset(params.Offset()).
|
||
Limit(params.Limit())
|
||
for _, order := range userListOrder(params) {
|
||
usersQuery = usersQuery.Order(order)
|
||
}
|
||
|
||
users, err := usersQuery.All(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
outUsers := make([]service.User, 0, len(users))
|
||
if len(users) == 0 {
|
||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||
}
|
||
|
||
userIDs := make([]int64, 0, len(users))
|
||
userMap := make(map[int64]*service.User, len(users))
|
||
for i := range users {
|
||
userIDs = append(userIDs, users[i].ID)
|
||
u := userEntityToService(users[i])
|
||
outUsers = append(outUsers, *u)
|
||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||
}
|
||
|
||
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
|
||
if shouldLoadSubscriptions {
|
||
// Batch load active subscriptions with groups to avoid N+1.
|
||
subs, err := r.client.UserSubscription.Query().
|
||
Where(
|
||
usersubscription.UserIDIn(userIDs...),
|
||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||
).
|
||
WithGroup().
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
for i := range subs {
|
||
if u, ok := userMap[subs[i].UserID]; ok {
|
||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||
}
|
||
}
|
||
}
|
||
|
||
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
for id, u := range userMap {
|
||
if groups, ok := allowedGroupsByUser[id]; ok {
|
||
u.AllowedGroups = groups
|
||
}
|
||
}
|
||
|
||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||
}
|
||
|
||
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||
|
||
if sortBy == "last_used_at" {
|
||
return userLastUsedAtOrder(sortOrder)
|
||
}
|
||
|
||
var field string
|
||
defaultField := true
|
||
nullsLastField := false
|
||
switch sortBy {
|
||
case "email":
|
||
field = dbuser.FieldEmail
|
||
defaultField = false
|
||
case "username":
|
||
field = dbuser.FieldUsername
|
||
defaultField = false
|
||
case "role":
|
||
field = dbuser.FieldRole
|
||
defaultField = false
|
||
case "balance":
|
||
field = dbuser.FieldBalance
|
||
defaultField = false
|
||
case "concurrency":
|
||
field = dbuser.FieldConcurrency
|
||
defaultField = false
|
||
case "status":
|
||
field = dbuser.FieldStatus
|
||
defaultField = false
|
||
case "created_at":
|
||
field = dbuser.FieldCreatedAt
|
||
defaultField = false
|
||
case "last_login_at":
|
||
field = dbuser.FieldLastLoginAt
|
||
defaultField = false
|
||
nullsLastField = true
|
||
case "last_active_at":
|
||
field = dbuser.FieldLastActiveAt
|
||
defaultField = false
|
||
nullsLastField = true
|
||
default:
|
||
field = dbuser.FieldID
|
||
}
|
||
|
||
if sortOrder == pagination.SortOrderAsc {
|
||
if defaultField && field == dbuser.FieldID {
|
||
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
|
||
}
|
||
if nullsLastField {
|
||
return []func(*entsql.Selector){
|
||
entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
|
||
dbent.Asc(dbuser.FieldID),
|
||
}
|
||
}
|
||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
|
||
}
|
||
if defaultField && field == dbuser.FieldID {
|
||
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
|
||
}
|
||
if nullsLastField {
|
||
return []func(*entsql.Selector){
|
||
entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
|
||
dbent.Desc(dbuser.FieldID),
|
||
}
|
||
}
|
||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
|
||
}
|
||
|
||
func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
|
||
result := make(map[int64]*time.Time, len(userIDs))
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
if r.sql == nil {
|
||
return nil, fmt.Errorf("sql executor is not configured")
|
||
}
|
||
|
||
const query = `
|
||
SELECT user_id, MAX(created_at) AS last_used_at
|
||
FROM usage_logs
|
||
WHERE user_id = ANY($1)
|
||
GROUP BY user_id
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() { _ = rows.Close() }()
|
||
|
||
for rows.Next() {
|
||
var (
|
||
userID int64
|
||
lastUsedAt time.Time
|
||
)
|
||
if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
|
||
return nil, scanErr
|
||
}
|
||
ts := lastUsedAt.UTC()
|
||
result[userID] = &ts
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
|
||
latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return latestByUserID[userID], nil
|
||
}
|
||
|
||
func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
|
||
orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
|
||
return func(s *entsql.Selector) {
|
||
subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
|
||
s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
|
||
s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
|
||
}
|
||
}
|
||
|
||
if sortOrder == pagination.SortOrderAsc {
|
||
return []func(*entsql.Selector){
|
||
orderExpr("ASC", "FIRST", entsql.Asc),
|
||
}
|
||
}
|
||
return []func(*entsql.Selector){
|
||
orderExpr("DESC", "LAST", entsql.Desc),
|
||
}
|
||
}
|
||
|
||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
|
||
if len(attrs) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
if r.sql == nil {
|
||
return nil, fmt.Errorf("sql executor is not configured")
|
||
}
|
||
|
||
clauses := make([]string, 0, len(attrs))
|
||
args := make([]any, 0, len(attrs)*2+1)
|
||
argIndex := 1
|
||
for attrID, value := range attrs {
|
||
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
|
||
args = append(args, attrID, "%"+value+"%")
|
||
argIndex += 2
|
||
}
|
||
|
||
query := fmt.Sprintf(
|
||
`SELECT user_id
|
||
FROM user_attribute_values
|
||
WHERE %s
|
||
GROUP BY user_id
|
||
HAVING COUNT(DISTINCT attribute_id) = $%d`,
|
||
strings.Join(clauses, " OR "),
|
||
argIndex,
|
||
)
|
||
args = append(args, len(attrs))
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() { _ = rows.Close() }()
|
||
|
||
result := make([]int64, 0)
|
||
for rows.Next() {
|
||
var userID int64
|
||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||
return nil, scanErr
|
||
}
|
||
result = append(result, userID)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
|
||
// Track cumulative recharge amount for percentage-based notifications
|
||
if amount > 0 {
|
||
update = update.AddTotalRecharged(amount)
|
||
}
|
||
n, err := update.Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
if n == 0 {
|
||
return service.ErrUserNotFound
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DeductBalance 扣除用户余额
|
||
// 透支策略:允许余额变为负数,确保当前请求能够完成
|
||
// 中间件会阻止余额 <= 0 的用户发起后续请求
|
||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
n, err := client.User.Update().
|
||
Where(dbuser.IDEQ(id)).
|
||
AddBalance(-amount).
|
||
Save(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if n == 0 {
|
||
return service.ErrUserNotFound
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
if n == 0 {
|
||
return service.ErrUserNotFound
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||
}
|
||
|
||
func userEmailLookupPredicate(email string) predicate.User {
|
||
normalized := strings.TrimSpace(email)
|
||
if normalized == "" {
|
||
return dbuser.EmailEQ(email)
|
||
}
|
||
return predicate.User(func(s *entsql.Selector) {
|
||
s.Where(entsql.ExprP(
|
||
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||
normalized,
|
||
))
|
||
})
|
||
}
|
||
|
||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
return client.UserAllowedGroup.Create().
|
||
SetUserID(userID).
|
||
SetGroupID(groupID).
|
||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||
DoNothing().
|
||
Exec(ctx)
|
||
}
|
||
|
||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||
affected, err := r.client.UserAllowedGroup.Delete().
|
||
Where(userallowedgroup.GroupIDEQ(groupID)).
|
||
Exec(ctx)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int64(affected), nil
|
||
}
|
||
|
||
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
|
||
func (r *userRepository) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.UserAllowedGroup.Delete().
|
||
Where(userallowedgroup.UserIDEQ(userID), userallowedgroup.GroupIDEQ(groupID)).
|
||
Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||
m, err := r.client.User.Query().
|
||
Where(
|
||
dbuser.RoleEQ(service.RoleAdmin),
|
||
dbuser.StatusEQ(service.StatusActive),
|
||
).
|
||
Order(dbent.Asc(dbuser.FieldID)).
|
||
First(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
|
||
out := userEntityToService(m)
|
||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if v, ok := groups[m.ID]; ok {
|
||
out.AllowedGroups = v
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
|
||
out := make(map[int64][]int64, len(userIDs))
|
||
if len(userIDs) == 0 {
|
||
return out, nil
|
||
}
|
||
|
||
rows, err := r.client.UserAllowedGroup.Query().
|
||
Where(userallowedgroup.UserIDIn(userIDs...)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for i := range rows {
|
||
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
|
||
}
|
||
|
||
for userID := range out {
|
||
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
|
||
}
|
||
|
||
return out, nil
|
||
}
|
||
|
||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||
if client == nil {
|
||
return nil
|
||
}
|
||
|
||
// Keep join table as the source of truth for reads.
|
||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
unique := make(map[int64]struct{}, len(groupIDs))
|
||
for _, id := range groupIDs {
|
||
if id <= 0 {
|
||
continue
|
||
}
|
||
unique[id] = struct{}{}
|
||
}
|
||
|
||
if len(unique) > 0 {
|
||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||
for groupID := range unique {
|
||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||
}
|
||
if err := client.UserAllowedGroup.
|
||
CreateBulk(creates...).
|
||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||
DoNothing().
|
||
Exec(ctx); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||
if dst == nil || src == nil {
|
||
return
|
||
}
|
||
dst.ID = src.ID
|
||
dst.SignupSource = src.SignupSource
|
||
dst.LastLoginAt = src.LastLoginAt
|
||
dst.LastActiveAt = src.LastActiveAt
|
||
dst.CreatedAt = src.CreatedAt
|
||
dst.UpdatedAt = src.UpdatedAt
|
||
}
|
||
|
||
// userSignupSourceOrDefault returns signupSource with surrounding whitespace
// removed, or "email" when nothing is left after trimming.
func userSignupSourceOrDefault(signupSource string) string {
	if trimmed := strings.TrimSpace(signupSource); trimmed != "" {
		return trimmed
	}
	return "email"
}
|
||
|
||
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
|
||
return service.MarshalNotifyEmails(entries)
|
||
}
|
||
|
||
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
update := client.User.UpdateOneID(userID)
|
||
if encryptedSecret == nil {
|
||
update = update.ClearTotpSecretEncrypted()
|
||
} else {
|
||
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||
}
|
||
_, err := update.Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// EnableTotp 启用用户的 TOTP 双因素认证
|
||
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.User.UpdateOneID(userID).
|
||
SetTotpEnabled(true).
|
||
SetTotpEnabledAt(time.Now()).
|
||
Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DisableTotp 禁用用户的 TOTP 双因素认证
|
||
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
|
||
client := clientFromContext(ctx, r.client)
|
||
_, err := client.User.UpdateOneID(userID).
|
||
SetTotpEnabled(false).
|
||
ClearTotpEnabledAt().
|
||
ClearTotpSecretEncrypted().
|
||
Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||
}
|
||
return nil
|
||
}
|