258 lines
7.5 KiB
Go
258 lines
7.5 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"strconv"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
)
|
||
|
||
// SoraQuotaService 管理 Sora 用户存储配额。
|
||
// 配额优先级:用户级 → 分组级 → 系统默认值。
|
||
type SoraQuotaService struct {
|
||
userRepo UserRepository
|
||
groupRepo GroupRepository
|
||
settingService *SettingService
|
||
}
|
||
|
||
// NewSoraQuotaService 创建配额服务实例。
|
||
func NewSoraQuotaService(
|
||
userRepo UserRepository,
|
||
groupRepo GroupRepository,
|
||
settingService *SettingService,
|
||
) *SoraQuotaService {
|
||
return &SoraQuotaService{
|
||
userRepo: userRepo,
|
||
groupRepo: groupRepo,
|
||
settingService: settingService,
|
||
}
|
||
}
|
||
|
||
// QuotaInfo 返回给客户端的配额信息。
|
||
type QuotaInfo struct {
|
||
QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制)
|
||
UsedBytes int64 `json:"used_bytes"` // 已使用
|
||
AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0)
|
||
QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited
|
||
Source string `json:"source,omitempty"` // 兼容旧字段
|
||
}
|
||
|
||
// ErrSoraStorageQuotaExceeded 表示配额不足。
|
||
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
|
||
|
||
// QuotaExceededError 包含配额不足的上下文信息。
|
||
type QuotaExceededError struct {
|
||
QuotaBytes int64
|
||
UsedBytes int64
|
||
}
|
||
|
||
func (e *QuotaExceededError) Error() string {
|
||
if e == nil {
|
||
return "存储配额不足"
|
||
}
|
||
return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
|
||
}
|
||
|
||
type soraQuotaAtomicUserRepository interface {
|
||
AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
|
||
ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
|
||
}
|
||
|
||
// GetQuota 获取用户的存储配额信息。
|
||
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
|
||
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
info := &QuotaInfo{
|
||
UsedBytes: user.SoraStorageUsedBytes,
|
||
}
|
||
|
||
// 1. 用户级配额
|
||
if user.SoraStorageQuotaBytes > 0 {
|
||
info.QuotaBytes = user.SoraStorageQuotaBytes
|
||
info.QuotaSource = "user"
|
||
info.Source = info.QuotaSource
|
||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||
return info, nil
|
||
}
|
||
|
||
// 2. 分组级配额(取用户可用分组中最大的配额)
|
||
if len(user.AllowedGroups) > 0 {
|
||
var maxGroupQuota int64
|
||
for _, gid := range user.AllowedGroups {
|
||
group, err := s.groupRepo.GetByID(ctx, gid)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
if group.SoraStorageQuotaBytes > maxGroupQuota {
|
||
maxGroupQuota = group.SoraStorageQuotaBytes
|
||
}
|
||
}
|
||
if maxGroupQuota > 0 {
|
||
info.QuotaBytes = maxGroupQuota
|
||
info.QuotaSource = "group"
|
||
info.Source = info.QuotaSource
|
||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||
return info, nil
|
||
}
|
||
}
|
||
|
||
// 3. 系统默认值
|
||
defaultQuota := s.getSystemDefaultQuota(ctx)
|
||
if defaultQuota > 0 {
|
||
info.QuotaBytes = defaultQuota
|
||
info.QuotaSource = "system"
|
||
info.Source = info.QuotaSource
|
||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||
return info, nil
|
||
}
|
||
|
||
// 无配额限制
|
||
info.QuotaSource = "unlimited"
|
||
info.Source = info.QuotaSource
|
||
info.AvailableBytes = 0
|
||
return info, nil
|
||
}
|
||
|
||
// CheckQuota 检查用户是否有足够的存储配额。
|
||
// 返回 nil 表示配额充足或无限制。
|
||
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
|
||
quota, err := s.GetQuota(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 0 表示无限制
|
||
if quota.QuotaBytes == 0 {
|
||
return nil
|
||
}
|
||
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
|
||
return &QuotaExceededError{
|
||
QuotaBytes: quota.QuotaBytes,
|
||
UsedBytes: quota.UsedBytes,
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// AddUsage 原子累加用量(上传成功后调用)。
|
||
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
|
||
if bytes <= 0 {
|
||
return nil
|
||
}
|
||
|
||
quota, err := s.GetQuota(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
|
||
return &QuotaExceededError{
|
||
QuotaBytes: quota.QuotaBytes,
|
||
UsedBytes: quota.UsedBytes,
|
||
}
|
||
}
|
||
|
||
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
|
||
newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
|
||
if err != nil {
|
||
if errors.Is(err, ErrSoraStorageQuotaExceeded) {
|
||
return &QuotaExceededError{
|
||
QuotaBytes: quota.QuotaBytes,
|
||
UsedBytes: quota.UsedBytes,
|
||
}
|
||
}
|
||
return fmt.Errorf("update user quota usage (atomic): %w", err)
|
||
}
|
||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
|
||
return nil
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return fmt.Errorf("get user for quota update: %w", err)
|
||
}
|
||
user.SoraStorageUsedBytes += bytes
|
||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||
return fmt.Errorf("update user quota usage: %w", err)
|
||
}
|
||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
|
||
return nil
|
||
}
|
||
|
||
// ReleaseUsage 释放用量(删除文件后调用)。
|
||
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
|
||
if bytes <= 0 {
|
||
return nil
|
||
}
|
||
|
||
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
|
||
newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
|
||
if err != nil {
|
||
return fmt.Errorf("update user quota release (atomic): %w", err)
|
||
}
|
||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
|
||
return nil
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return fmt.Errorf("get user for quota release: %w", err)
|
||
}
|
||
user.SoraStorageUsedBytes -= bytes
|
||
if user.SoraStorageUsedBytes < 0 {
|
||
user.SoraStorageUsedBytes = 0
|
||
}
|
||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||
return fmt.Errorf("update user quota release: %w", err)
|
||
}
|
||
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
|
||
return nil
|
||
}
|
||
|
||
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
|
||
if quotaBytes <= 0 {
|
||
return 0
|
||
}
|
||
if usedBytes >= quotaBytes {
|
||
return 0
|
||
}
|
||
return quotaBytes - usedBytes
|
||
}
|
||
|
||
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
|
||
if s.settingService == nil {
|
||
return 0
|
||
}
|
||
settings, err := s.settingService.GetSoraS3Settings(ctx)
|
||
if err != nil {
|
||
return 0
|
||
}
|
||
return settings.DefaultStorageQuotaBytes
|
||
}
|
||
|
||
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
|
||
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
|
||
return s.getSystemDefaultQuota(ctx)
|
||
}
|
||
|
||
// SetUserQuota 设置用户级配额(管理员操作)。
|
||
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
|
||
user, err := userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
user.SoraStorageQuotaBytes = quotaBytes
|
||
return userRepo.Update(ctx, user)
|
||
}
|
||
|
||
// ParseQuotaBytes 解析配额字符串为字节数。
|
||
func ParseQuotaBytes(s string) int64 {
|
||
v, _ := strconv.ParseInt(s, 10, 64)
|
||
return v
|
||
}
|