Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
@@ -43,6 +43,11 @@ type accountRepository struct {
|
||||
sql sqlExecutor // 原生 SQL 执行接口
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
until *time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||||
@@ -165,6 +170,11 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
accountIDs = append(accountIDs, acc.ID)
|
||||
}
|
||||
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -191,6 +201,10 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outByID[entAcc.ID] = out
|
||||
}
|
||||
|
||||
@@ -550,6 +564,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
|
||||
Where(
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -575,6 +590,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
|
||||
dbaccount.PlatformEQ(platform),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -607,6 +623,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
dbaccount.PlatformIn(platforms...),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
@@ -648,6 +665,31 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = $1,
|
||||
temp_unschedulable_reason = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
AND deleted_at IS NULL
|
||||
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
||||
`, until, reason, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND deleted_at IS NULL
|
||||
`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -808,6 +850,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
|
||||
now := time.Now()
|
||||
preds = append(preds,
|
||||
dbaccount.SchedulableEQ(true),
|
||||
tempUnschedulablePredicate(),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
)
|
||||
@@ -869,6 +912,10 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -894,12 +941,68 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[acc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outAccounts = append(outAccounts, *out)
|
||||
}
|
||||
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func tempUnschedulablePredicate() dbpredicate.Account {
|
||||
return dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.LTE(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
|
||||
FROM accounts
|
||||
WHERE id = ANY($1)
|
||||
`, pq.Array(accountIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var until sql.NullTime
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(&id, &until, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var untilPtr *time.Time
|
||||
if until.Valid {
|
||||
tmp := until.Time
|
||||
untilPtr = &tmp
|
||||
}
|
||||
if reason.Valid {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
|
||||
} else {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
|
||||
proxyMap := make(map[int64]*service.Proxy)
|
||||
if len(proxyIDs) == 0 {
|
||||
|
||||
@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
},
|
||||
accType: service.AccountTypeApiKey,
|
||||
accType: service.AccountTypeAPIKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
|
||||
@@ -24,7 +24,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,17 +16,17 @@ type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.ApiKey.Update().
|
||||
builder := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
exists, err := r.client.APIKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
ids, err := r.client.APIKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
n, err := r.client.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.ApiKey{
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
|
||||
@@ -12,30 +12,30 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
type APIKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
func (s *APIKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
func (s *APIKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
func (s *APIKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
@@ -150,7 +150,7 @@ func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
@@ -161,7 +161,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
@@ -174,7 +174,7 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
@@ -186,7 +186,7 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
@@ -202,7 +202,7 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
@@ -214,7 +214,7 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
@@ -227,41 +227,41 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
// --- SearchAPIKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
@@ -284,7 +284,7 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
@@ -320,8 +320,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
k := &service.APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
|
||||
@@ -5,28 +5,20 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
@@ -37,6 +29,12 @@ type requestCapture struct {
|
||||
contentType string
|
||||
}
|
||||
|
||||
func newTestReqClient(rt http.RoundTripper) *req.Client {
|
||||
c := req.C()
|
||||
c.GetClient().Transport = rt
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = s.srv.URL
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
|
||||
@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ type usageRequestCapture struct {
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
@@ -62,7 +62,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
@@ -79,7 +79,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
@@ -95,7 +95,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -309,7 +309,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// Package repository 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
|
||||
k.Name = "default"
|
||||
}
|
||||
|
||||
create := client.ApiKey.Create().
|
||||
create := client.APIKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
|
||||
@@ -30,6 +30,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
|
||||
@@ -56,7 +56,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
@@ -73,7 +73,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
@@ -98,7 +98,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
@@ -124,7 +124,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
@@ -139,7 +139,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
@@ -152,7 +152,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
@@ -163,7 +163,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -186,7 +186,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
@@ -220,7 +220,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
]
|
||||
}`
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||
@@ -246,7 +246,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
@@ -263,7 +263,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
@@ -280,7 +280,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
@@ -299,7 +299,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
|
||||
@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.ApiKey.Update().
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package repository
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -99,7 +98,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
|
||||
// 验证空代理 URL 时请求直接发送到目标服务器
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||
// 创建模拟上游服务器
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
@@ -121,7 +120,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
// 用于接收代理请求的通道
|
||||
seen := make(chan string, 1)
|
||||
// 创建模拟代理服务器
|
||||
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI // 记录请求 URI
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
@@ -151,7 +150,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
|
||||
// 验证空字符串代理等同于直连
|
||||
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct-empty")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
63
backend/internal/repository/inprocess_transport_test.go
Normal file
63
backend/internal/repository/inprocess_transport_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
|
||||
// It captures the request body (if any) and then rewinds it before invoking the handler.
|
||||
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
|
||||
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
var body []byte
|
||||
if r.Body != nil {
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
if capture != nil {
|
||||
capture(r, body)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, r)
|
||||
return rec.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
canListenOnce sync.Once
|
||||
canListen bool
|
||||
canListenErr error
|
||||
)
|
||||
|
||||
func localListenerAvailable() bool {
|
||||
canListenOnce.Do(func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
canListenErr = err
|
||||
canListen = false
|
||||
return
|
||||
}
|
||||
_ = ln.Close()
|
||||
canListen = true
|
||||
})
|
||||
return canListen
|
||||
}
|
||||
|
||||
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
|
||||
tb.Helper()
|
||||
if !localListenerAvailable() {
|
||||
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *PricingServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
|
||||
@@ -34,7 +34,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
s.proxySrv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
|
||||
@@ -41,8 +41,8 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
Name: "soft-delete",
|
||||
@@ -53,13 +53,13 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
_, err := repo.GetByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
|
||||
|
||||
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
require.Error(t, err, "default ent query should not see soft-deleted rows")
|
||||
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
|
||||
|
||||
got, err := client.ApiKey.Query().
|
||||
got, err := client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
|
||||
@@ -73,8 +73,8 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
Name: "soft-delete2",
|
||||
@@ -93,8 +93,8 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
Name: "soft-delete3",
|
||||
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
|
||||
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "hard delete")
|
||||
|
||||
_, err = client.ApiKey.Query().
|
||||
_, err = client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
|
||||
|
||||
91
backend/internal/repository/temp_unsched_cache.go
Normal file
91
backend/internal/repository/temp_unsched_cache.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const tempUnschedPrefix = "temp_unsched:account:"
|
||||
|
||||
var tempUnschedSetScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local new_until = tonumber(ARGV[1])
|
||||
local new_value = ARGV[2]
|
||||
local new_ttl = tonumber(ARGV[3])
|
||||
|
||||
local existing = redis.call('GET', key)
|
||||
if existing then
|
||||
local ok, existing_data = pcall(cjson.decode, existing)
|
||||
if ok and existing_data and existing_data.until_unix then
|
||||
local existing_until = tonumber(existing_data.until_unix)
|
||||
if existing_until and new_until <= existing_until then
|
||||
return 0
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call('SET', key, new_value, 'EX', new_ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
type tempUnschedCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
|
||||
return &tempUnschedCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
|
||||
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
stateJSON, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state: %w", err)
|
||||
}
|
||||
|
||||
ttl := time.Until(time.Unix(state.UntilUnix, 0))
|
||||
if ttl <= 0 {
|
||||
return nil // 已过期,不设置
|
||||
}
|
||||
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
if ttlSeconds < 1 {
|
||||
ttlSeconds = 1
|
||||
}
|
||||
|
||||
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTempUnsched 获取临时不可调度状态
|
||||
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state service.TempUnschedState
|
||||
if err := json.Unmarshal([]byte(val), &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
// DeleteTempUnsched 删除临时不可调度状态
|
||||
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -3,9 +3,9 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
type TurnstileServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
verifier *turnstileVerifier
|
||||
received chan url.Values
|
||||
}
|
||||
@@ -31,21 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
|
||||
s.verifier = verifier
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
|
||||
s.verifier.verifyURL = "http://in-process/turnstile"
|
||||
s.verifier.httpClient = &http.Client{
|
||||
Transport: newInProcessTransport(handler, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = httptest.NewServer(handler)
|
||||
s.verifier.verifyURL = s.srv.URL
|
||||
s.verifier.httpClient = s.srv.Client()
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture form data in main goroutine context later
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
@@ -73,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||
var contentType string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
contentType = r.Header.Get("Content-Type")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||
@@ -85,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
s.received <- values
|
||||
@@ -106,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
s.verifier.verifyURL = "http://in-process/turnstile"
|
||||
s.verifier.httpClient = &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||
require.Error(s.T(), err, "expected error when server is closed")
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
@@ -124,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||
}
|
||||
|
||||
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
|
||||
Success: false,
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -60,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if log == nil {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||||
sqlq := r.sql
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
sqlq = tx.Client()
|
||||
}
|
||||
|
||||
createdAt := log.CreatedAt
|
||||
@@ -70,6 +78,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
|
||||
query := `
|
||||
@@ -107,6 +118,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
@@ -115,11 +127,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
requestIDArg = requestID
|
||||
}
|
||||
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.ApiKeyID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
@@ -142,11 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
firstToken,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
@@ -183,7 +209,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||||
}
|
||||
|
||||
@@ -270,8 +296,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,8 +444,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
@@ -623,7 +649,7 @@ func resolveUsageStatsTimezone() string {
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
@@ -709,11 +735,11 @@ type ModelStat = usagestats.ModelStat
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -755,10 +781,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||
results = make([]APIKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row ApiKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
var row APIKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
@@ -844,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -853,7 +879,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1023,9 +1049,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.ApiKeyID > 0 {
|
||||
if filters.APIKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.ApiKeyID)
|
||||
args = append(args, filters.APIKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
@@ -1145,18 +1171,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchApiKeyUsageStats)
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
@@ -1582,7 +1608,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
|
||||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1603,8 +1629,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if user, ok := users[logs[i].UserID]; ok {
|
||||
logs[i].User = user
|
||||
}
|
||||
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
|
||||
logs[i].ApiKey = key
|
||||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||||
logs[i].APIKey = key
|
||||
}
|
||||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||||
logs[i].Account = acc
|
||||
@@ -1642,7 +1668,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||||
|
||||
for i := range logs {
|
||||
userIDs[logs[i].UserID] = struct{}{}
|
||||
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
|
||||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||||
accountIDs[logs[i].AccountID] = struct{}{}
|
||||
if logs[i].GroupID != nil {
|
||||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||||
@@ -1676,12 +1702,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
|
||||
out := make(map[int64]*service.ApiKey)
|
||||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||||
out := make(map[int64]*service.APIKey)
|
||||
if len(ids) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1800,7 +1826,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
log := &service.UsageLog{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(), // Generate unique RequestID for each log
|
||||
Model: "claude-3",
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
@@ -47,7 +50,8 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
|
||||
ActualCost: cost,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log))
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err)
|
||||
return log
|
||||
}
|
||||
|
||||
@@ -55,12 +59,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
@@ -69,14 +73,14 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
ActualCost: 0.4,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, log)
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -96,7 +100,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -112,7 +116,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -124,18 +128,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
// --- ListByApiKey ---
|
||||
// --- ListByAPIKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKey() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByApiKey")
|
||||
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByAPIKey")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
@@ -144,7 +148,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -159,7 +163,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -179,7 +183,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -211,8 +215,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
|
||||
@@ -223,7 +227,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
GroupID: &group.ID,
|
||||
@@ -236,11 +240,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d1,
|
||||
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
|
||||
_, err = s.repo.Create(s.ctx, logToday)
|
||||
s.Require().NoError(err, "Create logToday")
|
||||
|
||||
logOld := &service.UsageLog{
|
||||
UserID: userOld.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
@@ -250,11 +255,12 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d2,
|
||||
CreatedAt: todayStart.Add(-1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
|
||||
_, err = s.repo.Create(s.ctx, logOld)
|
||||
s.Require().NoError(err, "Create logOld")
|
||||
|
||||
logPerf := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 1,
|
||||
@@ -264,7 +270,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
DurationMs: &d3,
|
||||
CreatedAt: now.Add(-30 * time.Second),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
|
||||
_, err = s.repo.Create(s.ctx, logPerf)
|
||||
s.Require().NoError(err, "Create logPerf")
|
||||
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
@@ -272,8 +279,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
@@ -300,14 +307,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetUserDashboardStats")
|
||||
s.Require().Equal(int64(1), stats.TotalApiKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalAPIKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalRequests)
|
||||
}
|
||||
|
||||
@@ -315,7 +322,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -331,8 +338,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
@@ -351,24 +358,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
// --- GetBatchAPIKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -377,7 +384,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -402,7 +409,7 @@ func maxTime(a, b time.Time) time.Time {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -417,11 +424,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
// --- ListByAPIKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -431,8 +438,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
|
||||
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
@@ -440,7 +447,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -459,7 +466,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -467,7 +474,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 10,
|
||||
@@ -476,11 +483,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 15,
|
||||
@@ -489,11 +497,12 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.6,
|
||||
CreatedAt: base.Add(30 * time.Minute),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 20,
|
||||
@@ -502,7 +511,8 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
ActualCost: 0.7,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log3))
|
||||
_, err = s.repo.Create(s.ctx, log3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -515,7 +525,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
@@ -535,7 +545,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -552,7 +562,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -571,7 +581,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -579,7 +589,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -588,11 +598,12 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -601,7 +612,8 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -618,7 +630,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -646,7 +658,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -665,14 +677,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -681,11 +693,12 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: base,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -694,7 +707,8 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
ActualCost: 0.2,
|
||||
CreatedAt: base.Add(1 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
@@ -719,7 +733,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
@@ -727,7 +741,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
// Create logs on different days
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -736,11 +750,12 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
ActualCost: 0.4,
|
||||
CreatedAt: base.Add(12 * time.Hour),
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -749,7 +764,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
ActualCost: 0.15,
|
||||
CreatedAt: base.Add(36 * time.Hour), // next day
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
startTime := base
|
||||
endTime := base.Add(72 * time.Hour)
|
||||
@@ -782,8 +798,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -799,12 +815,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
// --- GetAPIKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -815,14 +831,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -832,8 +848,8 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
@@ -841,12 +857,12 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
|
||||
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters apiKey")
|
||||
s.Require().Len(logs, 1)
|
||||
@@ -855,7 +871,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -874,7 +890,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -885,7 +901,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -17,14 +18,15 @@ import (
|
||||
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client}
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
// If attribute filters are specified, we need to filter by user IDs first
|
||||
var allowedUserIDs []int64
|
||||
if len(filters.Attributes) > 0 {
|
||||
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
var attrErr error
|
||||
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
if attrErr != nil {
|
||||
return nil, nil, attrErr
|
||||
}
|
||||
if len(allowedUserIDs) == 0 {
|
||||
// No users match the attribute filters
|
||||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// For each attribute filter, get the set of matching user IDs
|
||||
// Then intersect all sets to get users matching ALL filters
|
||||
var resultSet map[int64]struct{}
|
||||
first := true
|
||||
if r.sql == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
clauses := make([]string, 0, len(attrs))
|
||||
args := make([]any, 0, len(attrs)*2+1)
|
||||
argIndex := 1
|
||||
for attrID, value := range attrs {
|
||||
// Query user_attribute_values for this attribute
|
||||
values, err := r.client.UserAttributeValue.Query().
|
||||
Where(
|
||||
userattributevalue.AttributeIDEQ(attrID),
|
||||
userattributevalue.ValueContainsFold(value),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSet := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
currentSet[v.UserID] = struct{}{}
|
||||
}
|
||||
|
||||
if first {
|
||||
resultSet = currentSet
|
||||
first = false
|
||||
} else {
|
||||
// Intersect with previous results
|
||||
for userID := range resultSet {
|
||||
if _, ok := currentSet[userID]; !ok {
|
||||
delete(resultSet, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if no users match
|
||||
if len(resultSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
|
||||
args = append(args, attrID, "%"+value+"%")
|
||||
argIndex += 2
|
||||
}
|
||||
|
||||
result := make([]int64, 0, len(resultSet))
|
||||
for userID := range resultSet {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT user_id
|
||||
FROM user_attribute_values
|
||||
WHERE %s
|
||||
GROUP BY user_id
|
||||
HAVING COUNT(DISTINCT attribute_id) = $%d`,
|
||||
strings.Join(clauses, " OR "),
|
||||
argIndex,
|
||||
)
|
||||
args = append(args, len(attrs))
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make([]int64, 0)
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||||
return nil, scanErr
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
|
||||
@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
NewApiKeyRepository,
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
@@ -42,7 +42,8 @@ var ProviderSet = wire.NewSet(
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
Reference in New Issue
Block a user