Files
sub2api/backend/internal/service/subscription_service.go
shaw c5b792add5 fix(billing): 修复限额为0时消费记录失败的问题
- 添加 normalizeLimit 函数,将 0 或负数限额规范化为 nil(无限制)
- 简化 IncrementUsage,移除冗余的配额检查逻辑
  - 配额检查已在请求前由中间件和网关完成
  - 消费记录应无条件执行,确保数据完整性
- 删除测试限额超出行为的无效集成测试
2025-12-31 22:48:35 +08:00

670 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// MaxExpiresAt is the latest expiration timestamp a subscription may carry
// (2099-12-31 23:59:59 UTC). Expiry times computed in this file are clamped to
// this value, which keeps them well inside the year range that RFC 3339 /
// time.Time JSON serialization accepts (year <= 9999).
var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
// MaxValidityDays is the upper bound on the number of days that may be granted
// in a single assignment or extension (36500 days, roughly 100 years).
const MaxValidityDays = 36500
// Sentinel errors returned by the subscription service. Each one is built with
// an infraerrors constructor from a machine-readable code and a human-readable
// message; the constructor name (NotFound, Forbidden, Conflict, BadRequest,
// TooManyRequests) indicates the error category.
// NOTE(review): the exact transport mapping (e.g. HTTP status) lives in the
// infraerrors package and is not visible here - confirm before relying on it.
var (
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
)
// SubscriptionService implements the subscription use cases in this file:
// assigning, extending and revoking user subscriptions, maintaining usage
// windows, and reporting usage progress.
type SubscriptionService struct {
// groupRepo is used to load groups (existence / subscription-type checks, limits).
groupRepo GroupRepository
// userSubRepo is the persistence layer for user subscriptions.
userSubRepo UserSubscriptionRepository
// billingCacheService is optional; when nil, cache invalidation is skipped.
billingCacheService *BillingCacheService
}
// NewSubscriptionService wires a SubscriptionService from its dependencies.
// billingCacheService may be nil, in which case the service skips every
// cache-invalidation step.
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
	svc := new(SubscriptionService)
	svc.groupRepo = groupRepo
	svc.userSubRepo = userSubRepo
	svc.billingCacheService = billingCacheService
	return svc
}
// AssignSubscriptionInput carries the parameters for assigning (or extending)
// a single user's subscription to a group.
type AssignSubscriptionInput struct {
// UserID is the user receiving the subscription.
UserID int64
// GroupID is the subscription-type group being subscribed to.
GroupID int64
// ValidityDays is the number of days to grant; values <= 0 fall back to 30
// and values above MaxValidityDays are capped.
ValidityDays int
// AssignedBy is the operator's ID; it is only recorded when > 0
// (0 denotes a system assignment, e.g. via a redeem code).
AssignedBy int64
// Notes is free-form text stored with the subscription.
Notes string
}
// AssignSubscription creates a new subscription for input.UserID in
// input.GroupID. Assigning twice to the same user/group pair is rejected.
//
// Errors:
//   - ErrSubscriptionNilInput when input is nil
//   - a wrapped "group not found" error when the group lookup fails
//   - ErrGroupNotSubscriptionType when the group is not a subscription group
//   - ErrSubscriptionAlreadyExists when the user already has a subscription
//   - any repository error from the existence check or the creation
//
// On success the subscription cache entry for the user/group pair is
// invalidated asynchronously and the freshly created subscription is returned.
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
	// Fail with a typed error instead of panicking on the first field access.
	if input == nil {
		return nil, ErrSubscriptionNilInput
	}

	// The target group must exist and be a subscription-type group.
	group, err := s.groupRepo.GetByID(ctx, input.GroupID)
	if err != nil {
		return nil, fmt.Errorf("group not found: %w", err)
	}
	if !group.IsSubscriptionType() {
		return nil, ErrGroupNotSubscriptionType
	}

	// Refuse to create a second subscription for the same user/group pair.
	exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
	if err != nil {
		return nil, err
	}
	if exists {
		return nil, ErrSubscriptionAlreadyExists
	}

	sub, err := s.createSubscription(ctx, input)
	if err != nil {
		return nil, err
	}

	// Invalidate the cached subscription in the background. A detached context
	// with its own timeout is used so the invalidation is not cancelled when
	// the request context ends.
	if s.billingCacheService != nil {
		userID, groupID := input.UserID, input.GroupID
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			// Best effort: a failed invalidation must not fail the assignment.
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}
	return sub, nil
}
// AssignOrExtendSubscription assigns a subscription or, when the user already
// has one for the group, extends it (used for redeem codes and similar flows).
//
// When a subscription already exists:
//   - not yet expired: the validity days are added to the current expiry time
//   - already expired: the new expiry is computed from the current time
//   - a non-active status is switched back to active
//   - input.Notes, when non-empty, is appended to the existing notes
//
// When no subscription exists a new one is created.
//
// The boolean result is true when an existing subscription was extended and
// false when a new one was created. ErrSubscriptionNilInput is returned for a
// nil input.
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
	// Fail with a typed error instead of panicking on the first field access.
	if input == nil {
		return nil, false, ErrSubscriptionNilInput
	}

	// The target group must exist and be a subscription-type group.
	group, err := s.groupRepo.GetByID(ctx, input.GroupID)
	if err != nil {
		return nil, false, fmt.Errorf("group not found: %w", err)
	}
	if !group.IsSubscriptionType() {
		return nil, false, ErrGroupNotSubscriptionType
	}

	// invalidateCache drops the cached subscription in the background. A
	// detached context with its own timeout is used so the invalidation is not
	// cancelled together with the request context.
	invalidateCache := func() {
		if s.billingCacheService == nil {
			return
		}
		userID, groupID := input.UserID, input.GroupID
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			// Best effort: a failed invalidation must not fail the operation.
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}

	existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
	if err != nil {
		// NOTE(review): every lookup error is treated as "no subscription yet",
		// so a transient repository failure falls through to the create path.
		// Distinguishing not-found from other errors requires knowledge of the
		// repository's error contract, which is not visible here.
		existingSub = nil
	}

	// No subscription yet: create a new one.
	if existingSub == nil {
		sub, err := s.createSubscription(ctx, input)
		if err != nil {
			return nil, false, err
		}
		invalidateCache()
		return sub, false, nil
	}

	// Normalize the validity period: default to 30 days, cap at MaxValidityDays.
	validityDays := input.ValidityDays
	if validityDays <= 0 {
		validityDays = 30
	}
	if validityDays > MaxValidityDays {
		validityDays = MaxValidityDays
	}

	// Extend from the current expiry while it is still in the future,
	// otherwise start counting from now.
	now := time.Now()
	base := now
	if existingSub.ExpiresAt.After(now) {
		base = existingSub.ExpiresAt
	}
	newExpiresAt := base.AddDate(0, 0, validityDays)
	if newExpiresAt.After(MaxExpiresAt) {
		newExpiresAt = MaxExpiresAt
	}

	if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
		return nil, false, fmt.Errorf("extend subscription: %w", err)
	}
	// The stored expiry has changed from here on, so the cache must be
	// invalidated on every exit path, including the error returns below.
	defer invalidateCache()

	// Reactivate a subscription that was expired or suspended.
	if existingSub.Status != SubscriptionStatusActive {
		if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
			return nil, false, fmt.Errorf("update subscription status: %w", err)
		}
	}

	// Append the new note on its own line; a failure here is logged only,
	// because the extension itself has already succeeded.
	if input.Notes != "" {
		newNotes := existingSub.Notes
		if newNotes != "" {
			newNotes += "\n"
		}
		newNotes += input.Notes
		if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
			log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
		}
	}

	// Return the refreshed subscription; true marks this as an extension.
	sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
	return sub, true, err
}
// createSubscription persists a brand-new active subscription built from
// input and returns it re-read from the repository (so that associations are
// populated). The validity period defaults to 30 days when not positive, is
// capped at MaxValidityDays, and the resulting expiry is clamped to
// MaxExpiresAt. Internal helper; callers validate the group beforehand.
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
	days := input.ValidityDays
	switch {
	case days <= 0:
		days = 30
	case days > MaxValidityDays:
		days = MaxValidityDays
	}

	createdAt := time.Now()
	expiry := createdAt.AddDate(0, 0, days)
	if expiry.After(MaxExpiresAt) {
		expiry = MaxExpiresAt
	}

	record := &UserSubscription{
		UserID:     input.UserID,
		GroupID:    input.GroupID,
		StartsAt:   createdAt,
		ExpiresAt:  expiry,
		Status:     SubscriptionStatusActive,
		AssignedAt: createdAt,
		Notes:      input.Notes,
		CreatedAt:  createdAt,
		UpdatedAt:  createdAt,
	}

	// AssignedBy is recorded only for a real operator (> 0); 0 denotes a
	// system assignment such as a redeem code.
	if input.AssignedBy > 0 {
		record.AssignedBy = &input.AssignedBy
	}

	if err := s.userSubRepo.Create(ctx, record); err != nil {
		return nil, err
	}

	// Re-read to obtain the complete subscription including associations.
	return s.userSubRepo.GetByID(ctx, record.ID)
}
// BulkAssignSubscriptionInput carries the parameters for assigning the same
// subscription to several users at once. GroupID, ValidityDays, AssignedBy and
// Notes are applied unchanged to every user in UserIDs.
type BulkAssignSubscriptionInput struct {
// UserIDs lists the users that should receive the subscription.
UserIDs []int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// BulkAssignResult summarizes the outcome of a bulk assignment.
type BulkAssignResult struct {
// SuccessCount is the number of users that received a subscription.
SuccessCount int
// FailedCount is the number of users whose assignment failed.
FailedCount int
// Subscriptions holds the subscriptions created for the successful users.
Subscriptions []UserSubscription
// Errors holds one "user <id>: <error>" message per failed user.
Errors []string
}
// BulkAssignSubscription assigns the same subscription to every user in
// input.UserIDs by calling AssignSubscription once per user.
//
// A failure for one user does not abort the batch: it is counted in
// FailedCount and described in Errors, and processing continues with the next
// user. The returned error is non-nil only when input itself is nil
// (ErrSubscriptionNilInput).
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
	// Fail with a typed error instead of panicking on the first field access.
	if input == nil {
		return nil, ErrSubscriptionNilInput
	}

	result := &BulkAssignResult{
		// At most one subscription per requested user.
		Subscriptions: make([]UserSubscription, 0, len(input.UserIDs)),
		Errors:        make([]string, 0),
	}

	for _, userID := range input.UserIDs {
		sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
			UserID:       userID,
			GroupID:      input.GroupID,
			ValidityDays: input.ValidityDays,
			AssignedBy:   input.AssignedBy,
			Notes:        input.Notes,
		})
		if err != nil {
			result.FailedCount++
			result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
			continue
		}
		result.SuccessCount++
		result.Subscriptions = append(result.Subscriptions, *sub)
	}
	return result, nil
}
// RevokeSubscription deletes the subscription with the given ID. The
// subscription is loaded first so its user/group pair is known for the
// asynchronous cache invalidation that follows a successful delete. Lookup and
// delete errors are returned unchanged.
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
	target, err := s.userSubRepo.GetByID(ctx, subscriptionID)
	if err != nil {
		return err
	}

	if delErr := s.userSubRepo.Delete(ctx, subscriptionID); delErr != nil {
		return delErr
	}

	// Without a cache service there is nothing left to do.
	if s.billingCacheService == nil {
		return nil
	}

	// Drop the cached subscription in the background with a detached,
	// time-limited context; the result is intentionally ignored.
	go func(userID, groupID int64) {
		cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
	}(target.UserID, target.GroupID)

	return nil
}
// ExtendSubscription adds days to the subscription's expiry and returns the
// refreshed subscription.
//
// days is capped at MaxValidityDays and the resulting expiry is clamped to
// MaxExpiresAt. For a positive extension of a subscription whose expiry has
// already passed, the days are counted from the current time rather than from
// the stale expiry, so the result always lies in the future (the same rule
// AssignOrExtendSubscription applies). A subscription in the expired status is
// switched back to active only when its new expiry is in the future.
//
// Any failure to load the subscription is reported as ErrSubscriptionNotFound;
// other repository errors are returned unchanged.
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
	sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
	if err != nil {
		return nil, ErrSubscriptionNotFound
	}

	// Cap the extension length.
	if days > MaxValidityDays {
		days = MaxValidityDays
	}

	// Pick the base time: the current expiry while it is still in the future,
	// otherwise now. Without this, extending a long-expired subscription could
	// yield an expiry that is still in the past.
	now := time.Now()
	base := sub.ExpiresAt
	if days > 0 && !base.After(now) {
		base = now
	}

	newExpiresAt := base.AddDate(0, 0, days)
	if newExpiresAt.After(MaxExpiresAt) {
		newExpiresAt = MaxExpiresAt
	}

	if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
		return nil, err
	}

	// Reactivate an expired subscription, but only if it is valid again.
	if sub.Status == SubscriptionStatusExpired && newExpiresAt.After(now) {
		if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil {
			return nil, err
		}
	}

	// Invalidate the cached subscription in the background with a detached,
	// time-limited context.
	if s.billingCacheService != nil {
		userID, groupID := sub.UserID, sub.GroupID
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			// Best effort: a failed invalidation must not fail the extension.
			_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
		}()
	}

	return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// GetByID loads a subscription by its ID, passing the repository's result and
// error through unchanged.
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) {
	sub, err := s.userSubRepo.GetByID(ctx, id)
	return sub, err
}
// GetActiveSubscription returns the user's active subscription for the given
// group. Any repository error is reported as ErrSubscriptionNotFound.
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
	active, lookupErr := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
	if lookupErr == nil {
		return active, nil
	}
	return nil, ErrSubscriptionNotFound
}
// ListUserSubscriptions returns every subscription of the user. Usage figures
// of windows that are due for a reset are zeroed in the returned data only.
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
	items, err := s.userSubRepo.ListByUserID(ctx, userID)
	if err == nil {
		normalizeExpiredWindows(items)
		return items, nil
	}
	return nil, err
}
// ListActiveUserSubscriptions returns the user's active subscriptions. Usage
// figures of windows that are due for a reset are zeroed in the returned data
// only.
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
	items, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
	if err == nil {
		normalizeExpiredWindows(items)
		return items, nil
	}
	return nil, err
}
// ListGroupSubscriptions returns one page of the group's subscriptions along
// with the pagination result. Usage figures of windows that are due for a
// reset are zeroed in the returned data only.
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) {
	var query pagination.PaginationParams
	query.Page = page
	query.PageSize = pageSize

	items, pageInfo, err := s.userSubRepo.ListByGroupID(ctx, groupID, query)
	if err == nil {
		normalizeExpiredWindows(items)
		return items, pageInfo, nil
	}
	return nil, nil, err
}
// List returns one page of subscriptions, filtered by the optional userID,
// groupID and status arguments, which are passed to the repository unchanged.
// Usage figures of windows that are due for a reset are zeroed in the returned
// data only.
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
	var query pagination.PaginationParams
	query.Page = page
	query.PageSize = pageSize

	items, pageInfo, err := s.userSubRepo.List(ctx, query, userID, groupID, status)
	if err == nil {
		normalizeExpiredWindows(items)
		return items, pageInfo, nil
	}
	return nil, nil, err
}
// normalizeExpiredWindows clears the usage data of every window that is due
// for a reset, for each subscription in subs. Only the in-memory slice is
// modified; nothing is written to the database. This lets callers present the
// state of the current window instead of figures left over from a finished one.
func normalizeExpiredWindows(subs []UserSubscription) {
	for idx := 0; idx < len(subs); idx++ {
		entry := &subs[idx]

		// Daily window.
		if entry.NeedsDailyReset() {
			entry.DailyUsageUSD = 0
			entry.DailyWindowStart = nil
		}

		// Weekly window.
		if entry.NeedsWeeklyReset() {
			entry.WeeklyUsageUSD = 0
			entry.WeeklyWindowStart = nil
		}

		// Monthly window.
		if entry.NeedsMonthlyReset() {
			entry.MonthlyUsageUSD = 0
			entry.MonthlyWindowStart = nil
		}
	}
}
// startOfDay returns midnight (00:00:00.000000000) of the calendar day that
// contains t, expressed in t's own location.
func startOfDay(t time.Time) time.Time {
	year, month, day := t.Date()
	return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
}
// CheckAndActivateWindow activates the subscription's usage windows on first
// use. It is a no-op when the windows are already activated; otherwise the
// windows are started at midnight of the current day.
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error {
	if !sub.IsWindowActivated() {
		// Midnight of the current day is the window start.
		return s.userSubRepo.ActivateWindows(ctx, sub.ID, startOfDay(time.Now()))
	}
	return nil
}
// CheckAndResetWindows resets every usage window (daily, weekly, monthly) of
// sub that is due for a reset. Each reset is persisted through the repository
// and mirrored on sub: the window start becomes midnight of the current day
// and the window's usage becomes 0.
//
// If at least one window was reset, the cached subscription is invalidated -
// also when a later reset fails, so that the cache never keeps serving usage
// figures the database has already cleared. The first repository error aborts
// the remaining resets and is returned.
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
	// New windows start at midnight of the current day.
	windowStart := startOfDay(time.Now())

	resetAny := false
	defer func() {
		if resetAny && s.billingCacheService != nil {
			// Best effort: a failed invalidation must not fail the reset.
			_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
		}
	}()

	// Daily window.
	if sub.NeedsDailyReset() {
		if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
			return err
		}
		// Each window gets its own copy so the three fields never alias.
		dailyStart := windowStart
		sub.DailyWindowStart = &dailyStart
		sub.DailyUsageUSD = 0
		resetAny = true
	}

	// Weekly window.
	if sub.NeedsWeeklyReset() {
		if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
			return err
		}
		weeklyStart := windowStart
		sub.WeeklyWindowStart = &weeklyStart
		sub.WeeklyUsageUSD = 0
		resetAny = true
	}

	// Monthly window.
	if sub.NeedsMonthlyReset() {
		if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
			return err
		}
		monthlyStart := windowStart
		sub.MonthlyWindowStart = &monthlyStart
		sub.MonthlyUsageUSD = 0
		resetAny = true
	}

	return nil
}
// CheckUsageLimits verifies that the subscription stays within the group's
// daily, weekly and monthly limits when additionalCost is added. The windows
// are checked in that order and the first violated one determines the error
// (ErrDailyLimitExceeded, ErrWeeklyLimitExceeded or ErrMonthlyLimitExceeded).
// It serves as a quick pre-check in middleware, where additionalCost is
// usually 0.
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
	switch {
	case !sub.CheckDailyLimit(group, additionalCost):
		return ErrDailyLimitExceeded
	case !sub.CheckWeeklyLimit(group, additionalCost):
		return ErrWeeklyLimitExceeded
	case !sub.CheckMonthlyLimit(group, additionalCost):
		return ErrMonthlyLimitExceeded
	}
	return nil
}
// RecordUsage adds costUSD to the recorded usage of the given subscription by
// delegating to the repository's IncrementUsage.
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
	if err := s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD); err != nil {
		return err
	}
	return nil
}
// SubscriptionProgress is the usage summary of a single subscription as
// produced by GetSubscriptionProgress.
type SubscriptionProgress struct {
// ID is the subscription ID.
ID int64 `json:"id"`
// GroupName is the name of the subscribed group.
GroupName string `json:"group_name"`
// ExpiresAt is the subscription's expiry time.
ExpiresAt time.Time `json:"expires_at"`
// ExpiresInDays is the value reported by the subscription's DaysRemaining.
ExpiresInDays int `json:"expires_in_days"`
// Daily, Weekly and Monthly are set only when the group has the matching
// limit and the matching window has a start time; otherwise they are nil
// and omitted from the JSON output.
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
}
// UsageWindowProgress describes the usage within one limit window (daily,
// weekly or monthly) as filled in by GetSubscriptionProgress.
type UsageWindowProgress struct {
// LimitUSD is the group's limit for the window.
LimitUSD float64 `json:"limit_usd"`
// UsedUSD is the usage recorded in the window.
UsedUSD float64 `json:"used_usd"`
// RemainingUSD is LimitUSD - UsedUSD, never below 0.
RemainingUSD float64 `json:"remaining_usd"`
// Percentage is UsedUSD / LimitUSD * 100, capped at 100.
Percentage float64 `json:"percentage"`
// WindowStart is when the window began.
WindowStart time.Time `json:"window_start"`
// ResetsAt is WindowStart plus the window length.
ResetsAt time.Time `json:"resets_at"`
// ResetsInSeconds is the time left until ResetsAt in whole seconds, never below 0.
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress reports how much of each limit window the
// subscription has used.
//
// The group is taken from the loaded subscription when present and fetched
// from the group repository otherwise. A window (daily = 24h, weekly = 7 days,
// monthly = 30 days) is included only when the group has the matching limit
// and the subscription's window has a start time.
//
// Any failure to load the subscription is reported as ErrSubscriptionNotFound;
// a group lookup error is returned unchanged.
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
	sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
	if err != nil {
		return nil, ErrSubscriptionNotFound
	}

	group := sub.Group
	if group == nil {
		group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
		if err != nil {
			return nil, err
		}
	}

	progress := &SubscriptionProgress{
		ID:            sub.ID,
		GroupName:     group.Name,
		ExpiresAt:     sub.ExpiresAt,
		ExpiresInDays: sub.DaysRemaining(),
	}

	const day = 24 * time.Hour

	if group.HasDailyLimit() && sub.DailyWindowStart != nil {
		progress.Daily = newUsageWindowProgress(*group.DailyLimitUSD, sub.DailyUsageUSD, *sub.DailyWindowStart, day)
	}
	if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
		progress.Weekly = newUsageWindowProgress(*group.WeeklyLimitUSD, sub.WeeklyUsageUSD, *sub.WeeklyWindowStart, 7*day)
	}
	if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
		progress.Monthly = newUsageWindowProgress(*group.MonthlyLimitUSD, sub.MonthlyUsageUSD, *sub.MonthlyWindowStart, 30*day)
	}

	return progress, nil
}

// newUsageWindowProgress builds the progress of one limit window that started
// at windowStart and lasts for window. Remaining amount and seconds-to-reset
// are floored at 0 and the percentage is capped at 100.
//
// It returns nil for a non-positive limit: such a limit means "unlimited", and
// dividing by it would produce NaN or Inf, which encoding/json refuses to
// marshal.
func newUsageWindowProgress(limitUSD, usedUSD float64, windowStart time.Time, window time.Duration) *UsageWindowProgress {
	if limitUSD <= 0 {
		return nil
	}

	resetsAt := windowStart.Add(window)
	p := &UsageWindowProgress{
		LimitUSD:        limitUSD,
		UsedUSD:         usedUSD,
		RemainingUSD:    limitUSD - usedUSD,
		Percentage:      (usedUSD / limitUSD) * 100,
		WindowStart:     windowStart,
		ResetsAt:        resetsAt,
		ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
	}
	if p.RemainingUSD < 0 {
		p.RemainingUSD = 0
	}
	if p.Percentage > 100 {
		p.Percentage = 100
	}
	if p.ResetsInSeconds < 0 {
		p.ResetsInSeconds = 0
	}
	return p
}
// GetUserSubscriptionsWithProgress returns the usage progress of each of the
// user's active subscriptions.
//
// A subscription whose progress cannot be computed is left out of the result;
// the failure is logged rather than dropped silently, and the remaining
// subscriptions are still returned. Only an error while listing the
// subscriptions is returned to the caller.
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
	subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}

	progresses := make([]SubscriptionProgress, 0, len(subs))
	// Index instead of ranging by value to avoid copying each subscription.
	for i := range subs {
		subID := subs[i].ID
		progress, err := s.GetSubscriptionProgress(ctx, subID)
		if err != nil {
			log.Printf("get subscription progress failed: sub_id=%d err=%v", subID, err)
			continue
		}
		progresses = append(progresses, *progress)
	}
	return progresses, nil
}
// UpdateExpiredSubscriptions delegates to the repository's
// BatchUpdateExpiredStatus and returns its count and error unchanged. It is
// meant to be invoked by a scheduled job.
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
	affected, err := s.userSubRepo.BatchUpdateExpiredStatus(ctx)
	return affected, err
}
// ValidateSubscription checks whether sub may currently be used. It returns
// ErrSubscriptionExpired for the expired status, ErrSubscriptionSuspended for
// the suspended status, and nil otherwise - unless sub.IsExpired() reports
// true, in which case the stored status is updated to expired (best effort)
// and ErrSubscriptionExpired is returned.
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
	switch sub.Status {
	case SubscriptionStatusExpired:
		return ErrSubscriptionExpired
	case SubscriptionStatusSuspended:
		return ErrSubscriptionSuspended
	}

	if !sub.IsExpired() {
		return nil
	}

	// The status has not caught up with the expiry yet; persist it. The result
	// is ignored because the subscription is rejected either way.
	_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired)
	return ErrSubscriptionExpired
}