Files
xinghuoapi/backend/internal/service/user_service.go
QTom 9a91815b94 feat(admin): 完整实现管理员修改用户 API Key 分组的功能
## 核心功能
- 添加 AdminUpdateAPIKeyGroupID 服务方法,支持绑定/解绑/保持不变三态语义
- 实现 UserRepository.AddGroupToAllowedGroups 接口,自动同步专属分组权限
- 添加 HTTP PUT /api-keys/:id handler 端点,支持管理员直接修改 API Key 分组

## 事务一致性
- 使用 ent Tx 保证专属分组绑定时「添加权限」和「更新 Key」的原子性
- Repository 方法支持 clientFromContext,兼容事务内调用
- 事务失败时自动回滚,避免权限孤立

## 业务逻辑
- 订阅类型分组阻断,需通过订阅管理流程
- 非活跃分组拒绝绑定
- 负 ID 和非法 ID 验证
- 自动授权响应,告知管理员成功授权的分组

## 代码质量
- 16 个单元测试覆盖所有业务路径和边界用例
- 7 个 handler 集成测试覆盖 HTTP 层
- GroupRepo stub 返回克隆副本,防止测试间数据泄漏
- API 类型安全修复(PaginatedResponse<ApiKey>)
- 前端 ref 回调类型对齐 Vue 规范

## 国际化支持
- 中英文提示信息完整
- 自动授权成功/失败提示
2026-02-28 20:18:14 +08:00

244 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// Package-level error values for user operations, built with the infraerrors
// constructors so each carries a machine-readable code and a message.
var (
// ErrUserNotFound is a not-found error with code USER_NOT_FOUND.
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
// ErrPasswordIncorrect is a bad-request error; ChangePassword returns it
// when the supplied current password fails verification.
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
// ErrInsufficientPerms is a forbidden error with code INSUFFICIENT_PERMISSIONS.
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
// UserListFilters contains all filter options for listing users.
// It is the filters argument of UserRepository.ListWithFilters.
// NOTE(review): presumably a zero-valued field means "no filter on this
// dimension" — confirm against the repository implementation.
type UserListFilters struct {
Status string // User status filter
Role string // User role filter
Search string // Search in email, username
Attributes map[int64]string // Custom attribute filters: attributeID -> value
}
// UserRepository is the persistence interface that UserService depends on
// for reading and mutating user records. Every method takes a context as its
// first argument.
type UserRepository interface {
// Basic CRUD and lookups.
Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error)
GetByEmail(ctx context.Context, email string) (*User, error)
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
// Paginated listing, without and with UserListFilters.
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
// Targeted field updates.
// NOTE(review): whether amount is a delta or an absolute value is not
// visible from this interface — confirm against the implementation.
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
// NOTE(review): the int64 result is presumably the number of affected
// users — verify against the implementation.
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// AddGroupToAllowedGroups incrementally adds the given group to the user's
// allowed_groups (idempotent; conflicts are ignored).
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// TOTP two-factor authentication.
// NOTE(review): encryptedSecret is a pointer, presumably so nil can clear
// the stored secret — confirm against the implementation.
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
EnableTotp(ctx context.Context, userID int64) error
DisableTotp(ctx context.Context, userID int64) error
}
// UpdateProfileRequest is the request payload for updating a user's profile.
// Every field is a pointer: UserService.UpdateProfile leaves the
// corresponding user field untouched when the pointer is nil.
type UpdateProfileRequest struct {
// Email, when set, is checked for uniqueness before being applied.
Email *string `json:"email"`
// Username, when set, replaces the user's username.
Username *string `json:"username"`
// Concurrency, when set, replaces the user's concurrency value.
Concurrency *int `json:"concurrency"`
}
// ChangePasswordRequest is the request payload for changing a user's password.
type ChangePasswordRequest struct {
// CurrentPassword is verified against the stored password before any
// change is made.
CurrentPassword string `json:"current_password"`
// NewPassword is the replacement password, applied via User.SetPassword.
NewPassword string `json:"new_password"`
}
// UserService implements user-related operations on top of a UserRepository.
type UserService struct {
// userRepo is the backing store for user records.
userRepo UserRepository
// authCacheInvalidator is nil-checked by the service methods before use,
// so it may be left unset.
authCacheInvalidator APIKeyAuthCacheInvalidator
// billingCache is nil-checked by UpdateBalance before use, so it may be
// left unset.
billingCache BillingCache
}
// NewUserService creates a UserService wired to the given repository, auth
// cache invalidator, and billing cache. The arguments are stored as-is.
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
	svc := new(UserService)
	svc.userRepo = userRepo
	svc.authCacheInvalidator = authCacheInvalidator
	svc.billingCache = billingCache
	return svc
}
// GetFirstAdmin returns the first administrator user reported by the
// repository (used for Admin API Key authentication). A repository failure
// is returned wrapped with "get first admin" context.
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
	firstAdmin, lookupErr := s.userRepo.GetFirstAdmin(ctx)
	if lookupErr == nil {
		return firstAdmin, nil
	}
	return nil, fmt.Errorf("get first admin: %w", lookupErr)
}
// GetProfile loads the profile of the user identified by userID. A
// repository failure is returned wrapped with "get user" context.
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
	profile, lookupErr := s.userRepo.GetByID(ctx, userID)
	if lookupErr == nil {
		return profile, nil
	}
	return nil, fmt.Errorf("get user: %w", lookupErr)
}
// UpdateProfile applies the non-nil fields of req to the user identified by
// userID, persists the result, and returns the updated user.
//
// Email handling: the uniqueness lookup runs only when req.Email differs from
// the stored address. Re-submitting the current address is a no-op, so it
// neither costs a repository round trip nor fails the whole update when that
// unnecessary lookup would have errored. If the new address is already in
// use, ErrEmailExists is returned and nothing is persisted.
//
// After a successful update the auth cache for the user is invalidated, but
// only when the concurrency value actually changed and an invalidator is
// configured.
//
// Errors from the repository are wrapped with the step that failed
// ("get user", "check email exists", "update user").
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Remember the previous value so the cache is only invalidated on a real change.
	oldConcurrency := user.Concurrency

	// Compare first, query second: an unchanged email needs no uniqueness check.
	if req.Email != nil && *req.Email != user.Email {
		exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
		if err != nil {
			return nil, fmt.Errorf("check email exists: %w", err)
		}
		if exists {
			return nil, ErrEmailExists
		}
		user.Email = *req.Email
	}

	if req.Username != nil {
		user.Username = *req.Username
	}

	if req.Concurrency != nil {
		user.Concurrency = *req.Concurrency
	}

	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, fmt.Errorf("update user: %w", err)
	}

	if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	return user, nil
}
// ChangePassword replaces the password of the user identified by userID.
//
// The current password in req must pass User.CheckPassword; otherwise
// ErrPasswordIncorrect is returned and nothing is changed.
//
// Security: TokenVersion is incremented so that all JWT tokens issued before
// the password change become invalid.
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
	account, loadErr := s.userRepo.GetByID(ctx, userID)
	if loadErr != nil {
		return fmt.Errorf("get user: %w", loadErr)
	}

	// Verify the caller knows the current password before touching anything.
	if passwordOK := account.CheckPassword(req.CurrentPassword); !passwordOK {
		return ErrPasswordIncorrect
	}

	if setErr := account.SetPassword(req.NewPassword); setErr != nil {
		return fmt.Errorf("set password: %w", setErr)
	}

	// Bump the token version so previously issued tokens stop being accepted.
	account.TokenVersion++

	if saveErr := s.userRepo.Update(ctx, account); saveErr != nil {
		return fmt.Errorf("update user: %w", saveErr)
	}
	return nil
}
// GetByID fetches a user by ID (administrator feature). A repository failure
// is returned wrapped with "get user" context.
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
	found, lookupErr := s.userRepo.GetByID(ctx, id)
	if lookupErr == nil {
		return found, nil
	}
	return nil, fmt.Errorf("get user: %w", lookupErr)
}
// List returns one page of users together with the pagination result
// (administrator feature). A repository failure is returned wrapped with
// "list users" context.
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	// The result is named pageInfo rather than pagination so the imported
	// package name stays visible inside this function.
	users, pageInfo, listErr := s.userRepo.List(ctx, params)
	if listErr == nil {
		return users, pageInfo, nil
	}
	return nil, nil, fmt.Errorf("list users: %w", listErr)
}
// UpdateBalance updates a user's balance through the repository
// (administrator feature) and then refreshes dependent caches.
//
// On success the auth cache for the user is invalidated synchronously when
// an invalidator is configured. The billing balance cache, when configured,
// is invalidated in a background goroutine under its own 5-second timeout; a
// failure there is only logged and never reaches the caller.
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
	updateErr := s.userRepo.UpdateBalance(ctx, userID, amount)
	if updateErr != nil {
		return fmt.Errorf("update balance: %w", updateErr)
	}

	if invalidator := s.authCacheInvalidator; invalidator != nil {
		invalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	if s.billingCache == nil {
		return nil
	}

	// Uses a fresh background context rather than ctx, bounded by its own timeout.
	go func() {
		cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		cacheErr := s.billingCache.InvalidateUserBalance(cacheCtx, userID)
		if cacheErr != nil {
			log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, cacheErr)
		}
	}()

	return nil
}
// UpdateConcurrency updates a user's concurrency value through the repository
// (administrator feature). On success the auth cache for the user is
// invalidated when an invalidator is configured.
func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concurrency int) error {
	updateErr := s.userRepo.UpdateConcurrency(ctx, userID, concurrency)
	if updateErr != nil {
		return fmt.Errorf("update concurrency: %w", updateErr)
	}

	if invalidator := s.authCacheInvalidator; invalidator != nil {
		invalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}
// UpdateStatus sets the status of the user identified by userID
// (administrator feature). The user is loaded, its Status field is
// overwritten with status, and the record is saved. On success the auth cache
// for the user is invalidated when an invalidator is configured.
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
	account, loadErr := s.userRepo.GetByID(ctx, userID)
	if loadErr != nil {
		return fmt.Errorf("get user: %w", loadErr)
	}

	account.Status = status

	saveErr := s.userRepo.Update(ctx, account)
	if saveErr != nil {
		return fmt.Errorf("update user: %w", saveErr)
	}

	if invalidator := s.authCacheInvalidator; invalidator != nil {
		invalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}
// Delete removes the user identified by userID (administrator feature).
// When an invalidator is configured, the auth cache for the user is
// invalidated before the repository delete is attempted. A repository failure
// is returned wrapped with "delete user" context.
func (s *UserService) Delete(ctx context.Context, userID int64) error {
	if invalidator := s.authCacheInvalidator; invalidator != nil {
		invalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	deleteErr := s.userRepo.Delete(ctx, userID)
	if deleteErr == nil {
		return nil
	}
	return fmt.Errorf("delete user: %w", deleteErr)
}