refactor(数据库): 迁移持久层到 Ent 并清理 GORM

将仓储层/基础设施改为 Ent + 原生 SQL 执行路径,并移除 AutoMigrate 与 GORM 依赖。
重构内容包括:
- 仓储层改用 Ent/SQL(含 usage_log/account 等复杂查询),统一错误映射
- 基础设施与 setup 初始化切换为 Ent + SQL migrations
- 集成测试与 fixtures 迁移到 Ent 事务模型
- 清理遗留 GORM 模型/依赖,补充迁移与文档说明
- 增加根目录 Makefile 便于前后端编译

测试:
- go test -tags unit ./...
- go test -tags integration ./...
This commit is contained in:
yangjianbo
2025-12-29 10:03:27 +08:00
parent fd51ff6970
commit 3d617de577
149 changed files with 62892 additions and 3212 deletions

View File

@@ -4,333 +4,336 @@ import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
type userSubscriptionRepository struct {
db *gorm.DB
client *dbent.Client
}
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptionRepository {
return &userSubscriptionRepository{client: client}
}
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(m).Error
if sub == nil {
return nil
}
builder := r.client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
SetNillableDailyWindowStart(sub.DailyWindowStart).
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
SetDailyUsageUsd(sub.DailyUsageUSD).
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
SetNillableAssignedBy(sub.AssignedBy)
if sub.StartsAt.IsZero() {
builder.SetStartsAt(time.Now())
} else {
builder.SetStartsAt(sub.StartsAt)
}
if sub.Status != "" {
builder.SetStatus(sub.Status)
}
if !sub.AssignedAt.IsZero() {
builder.SetAssignedAt(sub.AssignedAt)
}
// Keep compatibility with historical behavior: always store notes as a string value.
builder.SetNotes(sub.Notes)
created, err := builder.Save(ctx)
if err == nil {
applyUserSubscriptionModelToService(sub, m)
applyUserSubscriptionEntityToService(sub, created)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("User").
Preload("Group").
Preload("AssignedByUser").
First(&m, id).Error
m, err := r.client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
WithAssignedByUser().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&m).Error
m, err := r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&m).Error
m, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
sub.UpdatedAt = time.Now()
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
if sub == nil {
return nil
}
return err
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
SetExpiresAt(sub.ExpiresAt).
SetStatus(sub.Status).
SetNillableDailyWindowStart(sub.DailyWindowStart).
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
SetDailyUsageUsd(sub.DailyUsageUSD).
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
SetNillableAssignedBy(sub.AssignedBy).
SetAssignedAt(sub.AssignedAt).
SetNotes(sub.Notes)
updated, err := builder.Save(ctx)
if err == nil {
applyUserSubscriptionEntityToService(sub, updated)
return nil
}
return translatePersistenceError(err, service.ErrSubscriptionNotFound, service.ErrSubscriptionAlreadyExists)
}
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?",
userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC").
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
subs, err := q.
WithUser().
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
q := r.client.UserSubscription.Query()
if userID != nil {
query = query.Where("user_id = ?", *userID)
q = q.Where(usersubscription.UserIDEQ(*userID))
}
if groupID != nil {
query = query.Where("group_id = ?", *groupID)
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if status != "" {
query = query.Where("status = ?", status)
q = q.Where(usersubscription.StatusEQ(status))
}
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Preload("AssignedByUser").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
return r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": start,
"weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
"status": service.SubscriptionStatusExpired,
"updated_at": time.Now(),
})
return result.RowsAffected, result.Error
n, err := r.client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
).
SetStatus(service.SubscriptionStatusExpired).
Save(ctx)
return int64(n), err
}
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID).
Count(&count).Error
return count, err
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error
return count, err
count, err := r.client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
func userSubscriptionEntityToService(m *dbent.UserSubscription) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
out := &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
@@ -340,60 +343,42 @@ func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubsc
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
DailyUsageUSD: m.DailyUsageUsd,
WeeklyUsageUSD: m.WeeklyUsageUsd,
MonthlyUsageUSD: m.MonthlyUsageUsd,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
if m.Edges.AssignedByUser != nil {
out.AssignedByUser = userEntityToService(m.Edges.AssignedByUser)
}
return out
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
func userSubscriptionEntitiesToService(models []*dbent.UserSubscription) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
if s := userSubscriptionEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
func applyUserSubscriptionEntityToService(dst *service.UserSubscription, src *dbent.UserSubscription) {
if dst == nil || src == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}