Files
sub2api/backend/internal/service/sora_quota_service.go
2026-02-28 15:01:20 +08:00

258 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService 管理 Sora 用户存储配额。
// 配额优先级:用户级 → 分组级 → 系统默认值。
type SoraQuotaService struct {
userRepo UserRepository
groupRepo GroupRepository
settingService *SettingService
}
// NewSoraQuotaService 创建配额服务实例。
func NewSoraQuotaService(
userRepo UserRepository,
groupRepo GroupRepository,
settingService *SettingService,
) *SoraQuotaService {
return &SoraQuotaService{
userRepo: userRepo,
groupRepo: groupRepo,
settingService: settingService,
}
}
// QuotaInfo 返回给客户端的配额信息。
type QuotaInfo struct {
QuotaBytes int64 `json:"quota_bytes"` // 总配额0 表示无限制)
UsedBytes int64 `json:"used_bytes"` // 已使用
AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0
QuotaSource string `json:"quota_source"` // 配额来源user / group / system / unlimited
Source string `json:"source,omitempty"` // 兼容旧字段
}
// ErrSoraStorageQuotaExceeded 表示配额不足。
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
// QuotaExceededError 包含配额不足的上下文信息。
type QuotaExceededError struct {
QuotaBytes int64
UsedBytes int64
}
func (e *QuotaExceededError) Error() string {
if e == nil {
return "存储配额不足"
}
return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
}
type soraQuotaAtomicUserRepository interface {
AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
}
// GetQuota 获取用户的存储配额信息。
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
info := &QuotaInfo{
UsedBytes: user.SoraStorageUsedBytes,
}
// 1. 用户级配额
if user.SoraStorageQuotaBytes > 0 {
info.QuotaBytes = user.SoraStorageQuotaBytes
info.QuotaSource = "user"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 2. 分组级配额(取用户可用分组中最大的配额)
if len(user.AllowedGroups) > 0 {
var maxGroupQuota int64
for _, gid := range user.AllowedGroups {
group, err := s.groupRepo.GetByID(ctx, gid)
if err != nil {
continue
}
if group.SoraStorageQuotaBytes > maxGroupQuota {
maxGroupQuota = group.SoraStorageQuotaBytes
}
}
if maxGroupQuota > 0 {
info.QuotaBytes = maxGroupQuota
info.QuotaSource = "group"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
}
// 3. 系统默认值
defaultQuota := s.getSystemDefaultQuota(ctx)
if defaultQuota > 0 {
info.QuotaBytes = defaultQuota
info.QuotaSource = "system"
info.Source = info.QuotaSource
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
return info, nil
}
// 无配额限制
info.QuotaSource = "unlimited"
info.Source = info.QuotaSource
info.AvailableBytes = 0
return info, nil
}
// CheckQuota 检查用户是否有足够的存储配额。
// 返回 nil 表示配额充足或无限制。
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
// 0 表示无限制
if quota.QuotaBytes == 0 {
return nil
}
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return nil
}
// AddUsage 原子累加用量(上传成功后调用)。
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
quota, err := s.GetQuota(ctx, userID)
if err != nil {
return err
}
if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
if err != nil {
if errors.Is(err, ErrSoraStorageQuotaExceeded) {
return &QuotaExceededError{
QuotaBytes: quota.QuotaBytes,
UsedBytes: quota.UsedBytes,
}
}
return fmt.Errorf("update user quota usage (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota update: %w", err)
}
user.SoraStorageUsedBytes += bytes
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota usage: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
// ReleaseUsage 释放用量(删除文件后调用)。
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
if bytes <= 0 {
return nil
}
if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
if err != nil {
return fmt.Errorf("update user quota release (atomic): %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
return nil
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user for quota release: %w", err)
}
user.SoraStorageUsedBytes -= bytes
if user.SoraStorageUsedBytes < 0 {
user.SoraStorageUsedBytes = 0
}
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user quota release: %w", err)
}
logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
return nil
}
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
if quotaBytes <= 0 {
return 0
}
if usedBytes >= quotaBytes {
return 0
}
return quotaBytes - usedBytes
}
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
if s.settingService == nil {
return 0
}
settings, err := s.settingService.GetSoraS3Settings(ctx)
if err != nil {
return 0
}
return settings.DefaultStorageQuotaBytes
}
// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
return s.getSystemDefaultQuota(ctx)
}
// SetUserQuota 设置用户级配额(管理员操作)。
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
user, err := userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
user.SoraStorageQuotaBytes = quotaBytes
return userRepo.Update(ctx, user)
}
// ParseQuotaBytes 解析配额字符串为字节数。
func ParseQuotaBytes(s string) int64 {
v, _ := strconv.ParseInt(s, 10, 64)
return v
}