基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题: P0 生产 Bug: - 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03) - generateRandomID 改用 crypto/rand 替代固定索引 (P0-04) - IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05) 安全加固: - gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14) - usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16) - 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15) 性能优化: - gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02) - group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03) - usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07) - GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10) - ip.go CIDR 预编译为包级变量 (P1-11) - BatchUpdateCredentials 重构为先验证后更新 (P1-13) 缓存一致性: - billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10) - DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12) - UserService.UpdateBalance 成功后异步失效 billingCache (P2-13) 代码质量: - search 截断改为按 rune 处理,支持多字节字符 (P2-01) - TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07) - CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16) 测试覆盖: - 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例) - 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例) - 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go - 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试 - config_test.go 补充 server.mode/sslmode 默认值断言 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
242 lines
7.6 KiB
Go
242 lines
7.6 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"log"
|
||
"time"
|
||
|
||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
)
|
||
|
||
var (
|
||
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||
)
|
||
|
||
// UserListFilters bundles the optional criteria accepted when listing users.
type UserListFilters struct {
	Status     string           // filter on user status
	Role       string           // filter on user role
	Search     string           // search term matched against email and username
	Attributes map[int64]string // custom attribute filters: attribute ID -> wanted value
}
|
||
|
||
type UserRepository interface {
|
||
Create(ctx context.Context, user *User) error
|
||
GetByID(ctx context.Context, id int64) (*User, error)
|
||
GetByEmail(ctx context.Context, email string) (*User, error)
|
||
GetFirstAdmin(ctx context.Context) (*User, error)
|
||
Update(ctx context.Context, user *User) error
|
||
Delete(ctx context.Context, id int64) error
|
||
|
||
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
|
||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
|
||
|
||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||
|
||
// TOTP 双因素认证
|
||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||
EnableTotp(ctx context.Context, userID int64) error
|
||
DisableTotp(ctx context.Context, userID int64) error
|
||
}
|
||
|
||
// UpdateProfileRequest is the payload for a profile update.
// Every field is a pointer so that a nil value can mean "leave unchanged";
// UserService.UpdateProfile only applies the fields that are non-nil.
type UpdateProfileRequest struct {
	Email       *string `json:"email"`
	Username    *string `json:"username"`
	Concurrency *int    `json:"concurrency"`
}
|
||
|
||
// ChangePasswordRequest is the payload for a password change: the caller's
// current password (verified before anything is modified) and the new one.
type ChangePasswordRequest struct {
	CurrentPassword string `json:"current_password"`
	NewPassword     string `json:"new_password"`
}
|
||
|
||
// UserService 用户服务
|
||
type UserService struct {
|
||
userRepo UserRepository
|
||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||
billingCache BillingCache
|
||
}
|
||
|
||
// NewUserService 创建用户服务实例
|
||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
|
||
return &UserService{
|
||
userRepo: userRepo,
|
||
authCacheInvalidator: authCacheInvalidator,
|
||
billingCache: billingCache,
|
||
}
|
||
}
|
||
|
||
// GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证)
|
||
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
|
||
admin, err := s.userRepo.GetFirstAdmin(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get first admin: %w", err)
|
||
}
|
||
return admin, nil
|
||
}
|
||
|
||
// GetProfile 获取用户资料
|
||
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
return user, nil
|
||
}
|
||
|
||
// UpdateProfile 更新用户资料
|
||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
oldConcurrency := user.Concurrency
|
||
|
||
// 更新字段
|
||
if req.Email != nil {
|
||
// 检查新邮箱是否已被使用
|
||
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("check email exists: %w", err)
|
||
}
|
||
if exists && *req.Email != user.Email {
|
||
return nil, ErrEmailExists
|
||
}
|
||
user.Email = *req.Email
|
||
}
|
||
|
||
if req.Username != nil {
|
||
user.Username = *req.Username
|
||
}
|
||
|
||
if req.Concurrency != nil {
|
||
user.Concurrency = *req.Concurrency
|
||
}
|
||
|
||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||
return nil, fmt.Errorf("update user: %w", err)
|
||
}
|
||
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
// ChangePassword 修改密码
|
||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
// 验证当前密码
|
||
if !user.CheckPassword(req.CurrentPassword) {
|
||
return ErrPasswordIncorrect
|
||
}
|
||
|
||
if err := user.SetPassword(req.NewPassword); err != nil {
|
||
return fmt.Errorf("set password: %w", err)
|
||
}
|
||
|
||
// Increment TokenVersion to invalidate all existing tokens
|
||
// This ensures that any tokens issued before the password change become invalid
|
||
user.TokenVersion++
|
||
|
||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||
return fmt.Errorf("update user: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetByID 根据ID获取用户(管理员功能)
|
||
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
|
||
user, err := s.userRepo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
return user, nil
|
||
}
|
||
|
||
// List 获取用户列表(管理员功能)
|
||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||
users, pagination, err := s.userRepo.List(ctx, params)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||
}
|
||
return users, pagination, nil
|
||
}
|
||
|
||
// UpdateBalance 更新用户余额(管理员功能)
|
||
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
|
||
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
||
return fmt.Errorf("update balance: %w", err)
|
||
}
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
if s.billingCache != nil {
|
||
go func() {
|
||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||
}
|
||
}()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UpdateConcurrency 更新用户并发数(管理员功能)
|
||
func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concurrency int) error {
|
||
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
||
return fmt.Errorf("update concurrency: %w", err)
|
||
}
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UpdateStatus 更新用户状态(管理员功能)
|
||
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
user.Status = status
|
||
|
||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||
return fmt.Errorf("update user: %w", err)
|
||
}
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Delete 删除用户(管理员功能)
|
||
func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
||
if s.authCacheInvalidator != nil {
|
||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||
}
|
||
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
||
return fmt.Errorf("delete user: %w", err)
|
||
}
|
||
return nil
|
||
}
|