refactor(数据库): 迁移持久层到 Ent 并清理 GORM
将仓储层/基础设施改为 Ent + 原生 SQL 执行路径,并移除 AutoMigrate 与 GORM 依赖。 重构内容包括: - 仓储层改用 Ent/SQL(含 usage_log/account 等复杂查询),统一错误映射 - 基础设施与 setup 初始化切换为 Ent + SQL migrations - 集成测试与 fixtures 迁移到 Ent 事务模型 - 清理遗留 GORM 模型/依赖,补充迁移与文档说明 - 增加根目录 Makefile 便于前后端编译 测试: - go test -tags unit ./... - go test -tags integration ./...
This commit is contained in:
@@ -2,83 +2,118 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{db: db}
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
|
||||
m, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyModelToService(&m), nil
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
|
||||
m, err := r.client.ApiKey.Query().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyModelToService(&m), nil
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
|
||||
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
} else {
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
key.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
|
||||
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
@@ -86,11 +121,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(apiKeyIDs))
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||
Pluck("id", &ids).Error
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,136 +131,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
var keys []apiKeyModel
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
|
||||
|
||||
q := r.client.ApiKey.Query()
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
searchPattern := "%" + keyword + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
return result.RowsAffected, result.Error
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
type apiKeyModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
Key string `gorm:"uniqueIndex;size:128;not null"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
GroupID *int64 `gorm:"index"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (apiKeyModel) TableName() string { return "api_keys" }
|
||||
|
||||
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.ApiKey{
|
||||
out := &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
GroupID: m.GroupID,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
|
||||
if k == nil {
|
||||
func userEntityToService(u *dbent.User) *service.User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &apiKeyModel{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
|
||||
if key == nil || m == nil {
|
||||
return
|
||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
key.ID = m.ID
|
||||
key.CreatedAt = m.CreatedAt
|
||||
key.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user