fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -2,9 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -18,6 +20,11 @@ func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
@@ -35,7 +42,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
@@ -55,7 +62,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
@@ -69,7 +76,7 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
m, err := r.client.ApiKey.Query().
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
@@ -84,6 +91,14 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
exists, err := r.activeQuery().Where(apikey.IDEQ(key.ID)).Exist(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
|
||||
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status)
|
||||
@@ -105,12 +120,34 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -141,7 +178,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,17 +187,17 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
@@ -187,7 +224,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
q := r.client.ApiKey.Query()
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
@@ -211,7 +248,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
@@ -219,7 +256,7 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user