chore: 更新依赖、配置和代码生成

主要更新:
- 更新 go.mod/go.sum 依赖
- 重新生成 Ent ORM 代码
- 更新 Wire 依赖注入配置
- 添加 docker-compose.override.yml 到 .gitignore
- 更新 README 文档(Simple Mode 说明和已知问题)
- 清理调试日志
- 其他代码优化和格式修复
This commit is contained in:
ianshaw
2026-01-03 06:37:08 -08:00
parent b1702de522
commit 112a2d0866
121 changed files with 3058 additions and 2948 deletions

View File

@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeAPIKey,
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{

View File

@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.APIKey{
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",

View File

@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client
}
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}

View File

@@ -13,11 +13,11 @@ import (
"github.com/stretchr/testify/suite"
)
type APIKeyCacheSuite struct {
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -78,7 +78,7 @@ func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
}
}
func (s *APIKeyCacheSuite) TestDailyUsage() {
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -122,6 +122,6 @@ func (s *APIKeyCacheSuite) TestDailyUsage() {
}
}
func TestAPIKeyCacheSuite(t *testing.T) {
suite.Run(t, new(APIKeyCacheSuite))
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestAPIKeyRateLimitKey(t *testing.T) {
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64

View File

@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 APIKey 实体及其关联数据User、Group 等)
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
return 0, service.ErrApiKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.APIKey.Update().
builder := r.client.ApiKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.APIKey.Update().
affected, err := r.client.ApiKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.APIKey.Query().
exists, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists {
return nil
}
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids, err := r.client.APIKey.Query().
ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
return nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.APIKey.Update().
n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
if m == nil {
return nil
}
out := &service.APIKey{
out := &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,

View File

@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite"
)
type APIKeyRepoSuite struct {
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *APIKeyRepoSuite) SetupTest() {
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
}
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *APIKeyRepoSuite) TestCreate() {
func (s *ApiKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
@@ -51,16 +51,16 @@ func (s *APIKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key)
}
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *APIKeyRepoSuite) TestGetByKey() {
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
@@ -78,16 +78,16 @@ func (s *APIKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() {
func (s *ApiKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
@@ -108,10 +108,10 @@ func (s *APIKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
@@ -131,9 +131,9 @@ func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *APIKeyRepoSuite) TestDelete() {
func (s *ApiKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -150,10 +150,10 @@ func (s *APIKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *APIKeyRepoSuite) TestListByUserID() {
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -161,10 +161,10 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total)
}
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
@@ -174,10 +174,10 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages)
}
func (s *APIKeyRepoSuite) TestCountByUserID() {
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -186,13 +186,13 @@ func (s *APIKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *APIKeyRepoSuite) TestListByGroupID() {
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -202,10 +202,10 @@ func (s *APIKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User)
}
func (s *APIKeyRepoSuite) TestCountByGroupID() {
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -214,9 +214,9 @@ func (s *APIKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *APIKeyRepoSuite) TestExistsByKey() {
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -227,47 +227,47 @@ func (s *APIKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists)
}
// --- SearchAPIKeys ---
// --- SearchApiKeys ---
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -284,10 +284,10 @@ func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -320,13 +320,13 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
@@ -346,7 +346,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
@@ -359,7 +359,7 @@ func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u)
}
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
@@ -370,10 +370,10 @@ func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g)
}
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
s.T().Helper()
k := &service.APIKey{
k := &service.ApiKey{
UserID: userID,
Key: key,
Name: name,

View File

@@ -1,4 +1,4 @@
// Package repository 提供应用程序的基础设施层组件。
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository

View File

@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a
}
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
t.Helper()
ctx := context.Background()
@@ -257,7 +257,7 @@ func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
k.Name = "default"
}
create := client.APIKey.Create().
create := client.ApiKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).

View File

@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {

View File

@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
return u
}
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
@@ -53,28 +53,28 @@ func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.APIKey.Query().
got, err := client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
@@ -86,15 +86,15 @@ func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
@@ -105,10 +105,10 @@ func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.APIKey.Query().
_, err = client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")

View File

@@ -4,12 +4,13 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -17,14 +18,15 @@ import (
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
return nil
return nil, nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
args = append(args, attrID, "%"+value+"%")
argIndex += 2
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
query := fmt.Sprintf(
`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
strings.Join(clauses, " OR "),
argIndex,
)
args = append(args, len(attrs))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result := make([]int64, 0)
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID)
}
return result
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {

View File

@@ -28,13 +28,12 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewAPIKeyRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewOpsRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
@@ -43,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewAPIKeyCache,
NewApiKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,