feat: 品牌重命名 Sub2API -> TianShuAPI
- 前端: 所有界面显示、i18n 文本、组件中的品牌名称 - 后端: 服务层、设置默认值、邮件模板、安装向导 - 数据库: 迁移脚本注释 - 保持功能完全一致,仅更改品牌名称 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,462 +1,462 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
applyUserEntityToService(userIn, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[id]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
userIn.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, service.UserListFilters{})
|
||||
}
|
||||
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
q := r.client.User.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(dbuser.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Role != "" {
|
||||
q = q.Where(dbuser.RoleEQ(filters.Role))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// If attribute filters are specified, we need to filter by user IDs first
|
||||
var allowedUserIDs []int64
|
||||
if len(filters.Attributes) > 0 {
|
||||
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
if len(allowedUserIDs) == 0 {
|
||||
// No users match the attribute filters
|
||||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||
}
|
||||
q = q.Where(dbuser.IDIn(allowedUserIDs...))
|
||||
}
|
||||
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
users, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
if len(users) == 0 {
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
userMap := make(map[int64]*service.User, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
u := userEntityToService(users[i])
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Batch load active subscriptions with groups to avoid N+1.
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDIn(userIDs...),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
).
|
||||
WithGroup().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := range subs {
|
||||
if u, ok := userMap[subs[i].UserID]; ok {
|
||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||
}
|
||||
}
|
||||
|
||||
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for id, u := range userMap {
|
||||
if groups, ok := allowedGroupsByUser[id]; ok {
|
||||
u.AllowedGroups = groups
|
||||
}
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For each attribute filter, get the set of matching user IDs
|
||||
// Then intersect all sets to get users matching ALL filters
|
||||
var resultSet map[int64]struct{}
|
||||
first := true
|
||||
|
||||
for attrID, value := range attrs {
|
||||
// Query user_attribute_values for this attribute
|
||||
values, err := r.client.UserAttributeValue.Query().
|
||||
Where(
|
||||
userattributevalue.AttributeIDEQ(attrID),
|
||||
userattributevalue.ValueContainsFold(value),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSet := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
currentSet[v.UserID] = struct{}{}
|
||||
}
|
||||
|
||||
if first {
|
||||
resultSet = currentSet
|
||||
first = false
|
||||
} else {
|
||||
// Intersect with previous results
|
||||
for userID := range resultSet {
|
||||
if _, ok := currentSet[userID]; !ok {
|
||||
delete(resultSet, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if no users match
|
||||
if len(resultSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]int64, 0, len(resultSet))
|
||||
for userID := range resultSet {
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().
|
||||
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
|
||||
AddBalance(-amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrInsufficientBalance
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
affected, err := r.client.UserAllowedGroup.Delete().
|
||||
Where(userallowedgroup.GroupIDEQ(groupID)).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
m, err := r.client.User.Query().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.StatusEQ(service.StatusActive),
|
||||
).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
First(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
|
||||
out := make(map[int64][]int64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.UserAllowedGroup.Query().
|
||||
Where(userallowedgroup.UserIDIn(userIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
|
||||
}
|
||||
|
||||
for userID := range out {
|
||||
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
applyUserEntityToService(userIn, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[id]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
userIn.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, service.UserListFilters{})
|
||||
}
|
||||
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
q := r.client.User.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(dbuser.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Role != "" {
|
||||
q = q.Where(dbuser.RoleEQ(filters.Role))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// If attribute filters are specified, we need to filter by user IDs first
|
||||
var allowedUserIDs []int64
|
||||
if len(filters.Attributes) > 0 {
|
||||
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
if len(allowedUserIDs) == 0 {
|
||||
// No users match the attribute filters
|
||||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||
}
|
||||
q = q.Where(dbuser.IDIn(allowedUserIDs...))
|
||||
}
|
||||
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
users, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
if len(users) == 0 {
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
userMap := make(map[int64]*service.User, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
u := userEntityToService(users[i])
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Batch load active subscriptions with groups to avoid N+1.
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDIn(userIDs...),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
).
|
||||
WithGroup().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := range subs {
|
||||
if u, ok := userMap[subs[i].UserID]; ok {
|
||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||
}
|
||||
}
|
||||
|
||||
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for id, u := range userMap {
|
||||
if groups, ok := allowedGroupsByUser[id]; ok {
|
||||
u.AllowedGroups = groups
|
||||
}
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For each attribute filter, get the set of matching user IDs
|
||||
// Then intersect all sets to get users matching ALL filters
|
||||
var resultSet map[int64]struct{}
|
||||
first := true
|
||||
|
||||
for attrID, value := range attrs {
|
||||
// Query user_attribute_values for this attribute
|
||||
values, err := r.client.UserAttributeValue.Query().
|
||||
Where(
|
||||
userattributevalue.AttributeIDEQ(attrID),
|
||||
userattributevalue.ValueContainsFold(value),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSet := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
currentSet[v.UserID] = struct{}{}
|
||||
}
|
||||
|
||||
if first {
|
||||
resultSet = currentSet
|
||||
first = false
|
||||
} else {
|
||||
// Intersect with previous results
|
||||
for userID := range resultSet {
|
||||
if _, ok := currentSet[userID]; !ok {
|
||||
delete(resultSet, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if no users match
|
||||
if len(resultSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]int64, 0, len(resultSet))
|
||||
for userID := range resultSet {
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().
|
||||
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
|
||||
AddBalance(-amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrInsufficientBalance
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
affected, err := r.client.UserAllowedGroup.Delete().
|
||||
Where(userallowedgroup.GroupIDEQ(groupID)).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
m, err := r.client.User.Query().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.StatusEQ(service.StatusActive),
|
||||
).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
First(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
|
||||
out := make(map[int64][]int64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.UserAllowedGroup.Query().
|
||||
Where(userallowedgroup.UserIDIn(userIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
|
||||
}
|
||||
|
||||
for userID := range out {
|
||||
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user