refactor(数据库): 迁移持久层到 Ent 并清理 GORM
将仓储层/基础设施改为 Ent + 原生 SQL 执行路径,并移除 AutoMigrate 与 GORM 依赖。 重构内容包括: - 仓储层改用 Ent/SQL(含 usage_log/account 等复杂查询),统一错误映射 - 基础设施与 setup 初始化切换为 Ent + SQL migrations - 集成测试与 fixtures 迁移到 Ent 事务模型 - 清理遗留 GORM 模型/依赖,补充迁移与文档说明 - 增加根目录 Makefile 便于前后端编译 测试: - go test -tags unit ./... - go test -tags integration ./...
This commit is contained in:
@@ -195,8 +195,10 @@ func setDefaults() {
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
|
||||
// Default
|
||||
viper.SetDefault("default.admin_email", "admin@sub2api.com")
|
||||
viper.SetDefault("default.admin_password", "admin123")
|
||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||
viper.SetDefault("default.admin_email", "")
|
||||
viper.SetDefault("default.admin_password", "")
|
||||
viper.SetDefault("default.user_concurrency", 5)
|
||||
viper.SetDefault("default.user_balance", 0)
|
||||
viper.SetDefault("default.api_key_prefix", "sk-")
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
// 初始化时区(在数据库连接之前,确保时区设置正确)
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gormConfig := &gorm.Config{}
|
||||
if cfg.Server.Mode == "debug" {
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||
}
|
||||
|
||||
// 使用带时区的 DSN 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||
if err := repository.AutoMigrate(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
65
backend/internal/infrastructure/ent.go
Normal file
65
backend/internal/infrastructure/ent.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
|
||||
)
|
||||
|
||||
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
|
||||
//
|
||||
// 该函数执行以下操作:
|
||||
// 1. 初始化全局时区设置,确保时间处理一致性
|
||||
// 2. 建立 PostgreSQL 数据库连接
|
||||
// 3. 自动执行数据库迁移,确保 schema 与代码同步
|
||||
// 4. 创建并返回 Ent 客户端实例
|
||||
//
|
||||
// 重要提示:调用者必须负责关闭返回的 ent.Client(关闭时会自动关闭底层的 driver/db)。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
|
||||
//
|
||||
// 返回:
|
||||
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
|
||||
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
|
||||
// - error: 初始化过程中的错误
|
||||
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
|
||||
// 这对于跨时区部署和日志时间戳的一致性至关重要。
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建包含时区信息的数据库连接字符串 (DSN)。
|
||||
// 时区信息会传递给 PostgreSQL,确保数据库层面的时间处理正确。
|
||||
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
|
||||
|
||||
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
|
||||
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
|
||||
drv, err := entsql.Open(dialect.Postgres, dsn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 确保数据库 schema 已准备就绪。
|
||||
// SQL 迁移文件是 schema 的权威来源(source of truth)。
|
||||
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
|
||||
if err := applyMigrationsFS(context.Background(), drv.DB(), migrations.FS); err != nil {
|
||||
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
184
backend/internal/infrastructure/migrations_runner.go
Normal file
184
backend/internal/infrastructure/migrations_runner.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
)
|
||||
|
||||
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
|
||||
// 该表用于跟踪已应用的迁移文件及其校验和。
|
||||
// - filename: 迁移文件名,作为主键唯一标识每个迁移
|
||||
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
|
||||
// - applied_at: 迁移应用时间戳
|
||||
const schemaMigrationsTableDDL = `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
checksum TEXT NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
`
|
||||
|
||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||
const migrationsAdvisoryLockID int64 = 694208311321144027
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
//
|
||||
// 该函数可以在每次应用启动时安全调用:
|
||||
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
|
||||
// - 如果迁移文件内容被修改(checksum 不匹配),会返回错误
|
||||
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文,用于超时控制和取消
|
||||
// - db: 数据库连接
|
||||
//
|
||||
// 返回:
|
||||
// - error: 迁移过程中的任何错误
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
return applyMigrationsFS(ctx, db, migrations.FS)
|
||||
}
|
||||
|
||||
// applyMigrationsFS 是迁移执行的核心实现。
|
||||
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
|
||||
//
|
||||
// 迁移执行流程:
|
||||
// 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移
|
||||
// 2. 确保 schema_migrations 表存在
|
||||
// 3. 按文件名排序读取所有 .sql 文件
|
||||
// 4. 对于每个迁移文件:
|
||||
// - 计算文件内容的 SHA256 校验和
|
||||
// - 检查该迁移是否已应用(通过 filename 查询)
|
||||
// - 如果已应用,验证校验和是否匹配
|
||||
// - 如果未应用,在事务中执行迁移并记录
|
||||
// 5. 释放 Advisory Lock
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - db: 数据库连接
|
||||
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS)
|
||||
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
|
||||
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
|
||||
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
|
||||
if err := pgAdvisoryLock(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// 无论迁移是否成功,都要释放锁。
|
||||
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
|
||||
_ = pgAdvisoryUnlock(context.Background(), db)
|
||||
}()
|
||||
|
||||
// 创建迁移记录表(如果不存在)。
|
||||
// 该表记录所有已应用的迁移及其校验和。
|
||||
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
|
||||
return fmt.Errorf("create schema_migrations: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list migrations: %w", err)
|
||||
}
|
||||
sort.Strings(files) // 确保按文件名顺序执行迁移
|
||||
|
||||
for _, name := range files {
|
||||
// 读取迁移文件内容
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
if content == "" {
|
||||
continue // 跳过空文件
|
||||
}
|
||||
|
||||
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
|
||||
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
// 检查该迁移是否已经应用
|
||||
var existing string
|
||||
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
|
||||
if rowErr == nil {
|
||||
// 迁移已应用,验证校验和是否匹配
|
||||
if existing != checksum {
|
||||
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
||||
// 正确的做法是创建新的迁移文件来进行变更。
|
||||
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
|
||||
}
|
||||
continue // 迁移已应用且校验和匹配,跳过
|
||||
}
|
||||
if !errors.Is(rowErr, sql.ErrNoRows) {
|
||||
return fmt.Errorf("check migration %s: %w", name, rowErr)
|
||||
}
|
||||
|
||||
// 迁移未应用,在事务中执行。
|
||||
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 执行迁移 SQL
|
||||
if _, err := tx.ExecContext(ctx, content); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("apply migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 记录迁移已完成,保存文件名和校验和
|
||||
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("record migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("commit migration %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, "SELECT pg_advisory_lock($1)", migrationsAdvisoryLockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire migrations lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
|
||||
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
|
||||
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release migrations lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,25 +1,79 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// ProviderSet 提供基础设施层的依赖
|
||||
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
|
||||
//
|
||||
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
|
||||
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
|
||||
//
|
||||
// 包含的提供者:
|
||||
// - ProvideEnt: 提供 Ent ORM 客户端
|
||||
// - ProvideSQLDB: 提供底层 SQL 数据库连接
|
||||
// - ProvideRedis: 提供 Redis 客户端
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideDB,
|
||||
ProvideEnt,
|
||||
ProvideSQLDB,
|
||||
ProvideRedis,
|
||||
)
|
||||
|
||||
// ProvideDB 提供数据库连接
|
||||
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
return InitDB(cfg)
|
||||
// ProvideEnt 为依赖注入提供 Ent 客户端。
|
||||
//
|
||||
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
|
||||
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*ent.Client
|
||||
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
|
||||
client, _, err := InitEnt(cfg)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// ProvideRedis 提供 Redis 客户端
|
||||
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
|
||||
//
|
||||
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
|
||||
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
|
||||
//
|
||||
// 设计说明:
|
||||
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
|
||||
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
|
||||
//
|
||||
// 依赖:*ent.Client
|
||||
// 提供:*sql.DB
|
||||
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("nil ent client")
|
||||
}
|
||||
// 从 Ent 客户端获取底层驱动
|
||||
drv, ok := client.Driver().(*entsql.Driver)
|
||||
if !ok {
|
||||
return nil, errors.New("ent driver does not expose *sql.DB")
|
||||
}
|
||||
// 返回驱动持有的 sql.DB 实例
|
||||
return drv.DB(), nil
|
||||
}
|
||||
|
||||
// ProvideRedis 为依赖注入提供 Redis 客户端。
|
||||
//
|
||||
// Redis 用于:
|
||||
// - 分布式锁(如并发控制)
|
||||
// - 缓存(如用户会话、API 响应缓存)
|
||||
// - 速率限制
|
||||
// - 实时统计数据
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*redis.Client
|
||||
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||
return InitRedis(cfg)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,27 +4,31 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
tx *sql.Tx
|
||||
client *dbent.Client
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewAccountRepository(s.db).(*accountRepository)
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
s.tx = tx
|
||||
s.repo = newAccountRepositoryWithSQL(client, tx)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -61,7 +65,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
@@ -73,7 +77,7 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -83,23 +87,23 @@ func (s *AccountRepoSuite) TestDelete() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
var count int64
|
||||
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
|
||||
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -110,7 +114,7 @@ func (s *AccountRepoSuite) TestList() {
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(db *gorm.DB)
|
||||
setup func(client *dbent.Client)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
@@ -120,9 +124,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
|
||||
},
|
||||
platform: service.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
@@ -132,9 +136,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
},
|
||||
accType: service.AccountTypeApiKey,
|
||||
wantCount: 1,
|
||||
@@ -144,9 +148,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
|
||||
},
|
||||
status: service.StatusDisabled,
|
||||
wantCount: 1,
|
||||
@@ -156,9 +160,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
@@ -171,11 +175,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
db := testTx(s.T())
|
||||
repo := NewAccountRepository(db).(*accountRepository)
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(db)
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
s.Require().NoError(err)
|
||||
@@ -190,11 +194,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
|
||||
|
||||
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListByGroup")
|
||||
@@ -204,8 +208,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -214,8 +218,8 @@ func (s *AccountRepoSuite) TestListActive() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
@@ -226,14 +230,14 @@ func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
@@ -257,9 +261,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
|
||||
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
@@ -279,9 +283,9 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
|
||||
@@ -294,14 +298,14 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
s.Require().NoError(err, "ListSchedulable")
|
||||
@@ -312,17 +316,17 @@ func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
|
||||
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||
@@ -339,8 +343,8 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
@@ -349,11 +353,11 @@ func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
@@ -362,7 +366,7 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
@@ -374,7 +378,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
@@ -386,7 +390,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
@@ -399,7 +403,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
@@ -416,7 +420,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
@@ -429,7 +433,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
@@ -442,7 +446,7 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -458,9 +462,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra",
|
||||
Extra: datatypes.JSONMap{"a": "1"},
|
||||
Extra: map[string]any{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
@@ -471,12 +475,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
@@ -488,9 +492,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-crs",
|
||||
Extra: datatypes.JSONMap{"crs_account_id": crsID},
|
||||
Extra: map[string]any{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
@@ -514,8 +518,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
@@ -531,13 +535,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-cred",
|
||||
Credentials: datatypes.JSONMap{"existing": "value"},
|
||||
Credentials: map[string]any{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: datatypes.JSONMap{"new_key": "new_value"},
|
||||
Credentials: map[string]any{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -547,13 +551,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-extra",
|
||||
Extra: datatypes.JSONMap{"existing": "val"},
|
||||
Extra: map[string]any{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: datatypes.JSONMap{"new_key": "new_val"},
|
||||
Extra: map[string]any{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -569,7 +573,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
entClient, sqlTx := testEntSQLTx(t)
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, sqlTx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
entClient, sqlTx := testEntSQLTx(t)
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
|
||||
@@ -2,83 +2,118 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{db: db}
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
|
||||
m, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyModelToService(&m), nil
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
|
||||
m, err := r.client.ApiKey.Query().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyModelToService(&m), nil
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
|
||||
builder := r.client.ApiKey.UpdateOneID(key.ID).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
} else {
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
key.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
|
||||
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
@@ -86,11 +121,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(apiKeyIDs))
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||
Pluck("id", &ids).Error
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,136 +131,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
var keys []apiKeyModel
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
|
||||
|
||||
q := r.client.ApiKey.Query()
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
searchPattern := "%" + keyword + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
return result.RowsAffected, result.Error
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
type apiKeyModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
Key string `gorm:"uniqueIndex;size:128;not null"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
GroupID *int64 `gorm:"index"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (apiKeyModel) TableName() string { return "api_keys" }
|
||||
|
||||
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.ApiKey{
|
||||
out := &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
GroupID: m.GroupID,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
|
||||
if k == nil {
|
||||
func userEntityToService(u *dbent.User) *service.User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &apiKeyModel{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
|
||||
if key == nil || m == nil {
|
||||
return
|
||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
key.ID = m.ID
|
||||
key.CreatedAt = m.CreatedAt
|
||||
key.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
@@ -6,23 +6,24 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *apiKeyRepository
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.client = entClient
|
||||
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
@@ -32,7 +33,7 @@ func TestApiKeyRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
@@ -56,16 +57,17 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
@@ -84,13 +86,14 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: service.StatusActive,
|
||||
}))
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
@@ -106,14 +109,16 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
}))
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
@@ -127,12 +132,14 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
})
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -144,9 +151,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
@@ -155,13 +162,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-page-" + string(rune('a'+i)),
|
||||
Name: "Key",
|
||||
})
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
@@ -172,9 +175,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
@@ -184,12 +187,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
@@ -200,10 +203,9 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
@@ -213,8 +215,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
@@ -228,9 +230,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
@@ -239,9 +241,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
@@ -249,8 +251,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
@@ -260,12 +262,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
@@ -283,16 +285,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
|
||||
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-1",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}))
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
@@ -330,12 +326,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-2",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
@@ -353,3 +345,41 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
GroupID: groupID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MaxExpiresAt is the maximum allowed expiration date for subscriptions (year 2099)
|
||||
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
|
||||
var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
|
||||
|
||||
// AutoMigrate runs schema migrations for all repository persistence models.
|
||||
// Persistence models are defined within individual `*_repo.go` files.
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
err := db.AutoMigrate(
|
||||
&userModel{},
|
||||
&apiKeyModel{},
|
||||
&groupModel{},
|
||||
&accountModel{},
|
||||
&accountGroupModel{},
|
||||
&proxyModel{},
|
||||
&redeemCodeModel{},
|
||||
&usageLogModel{},
|
||||
&settingModel{},
|
||||
&userSubscriptionModel{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
|
||||
return fixInvalidExpiresAt(db)
|
||||
}
|
||||
|
||||
// fixInvalidExpiresAt 修复 user_subscriptions 表中无效的过期时间
|
||||
func fixInvalidExpiresAt(db *gorm.DB) error {
|
||||
result := db.Model(&userSubscriptionModel{}).
|
||||
Where("expires_at > ?", maxExpiresAt).
|
||||
Update("expires_at", maxExpiresAt)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
log.Printf("[AutoMigrate] Fixed %d subscriptions with invalid expires_at (year > 2099)", result.RowsAffected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,38 +1,75 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"gorm.io/gorm"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// translatePersistenceError 将数据库层错误翻译为业务层错误。
|
||||
//
|
||||
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
|
||||
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound)
|
||||
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows)。
|
||||
//
|
||||
// 参数:
|
||||
// - err: 原始数据库错误
|
||||
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
|
||||
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
|
||||
//
|
||||
// 返回:
|
||||
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
|
||||
// Ent 使用自定义的 NotFoundError,而标准库使用 sql.ErrNoRows。
|
||||
// 这里同时处理两种情况,保持业务错误映射一致。
|
||||
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
|
||||
return notFound.WithCause(err)
|
||||
}
|
||||
|
||||
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
|
||||
if conflict != nil && isUniqueConstraintViolation(err) {
|
||||
return conflict.WithCause(err)
|
||||
}
|
||||
|
||||
// 未匹配任何规则,返回原始错误
|
||||
return err
|
||||
}
|
||||
|
||||
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
|
||||
//
|
||||
// 支持多种检测方式:
|
||||
// 1. PostgreSQL 特定错误码 23505(唯一约束冲突)
|
||||
// 2. 错误消息中包含的通用关键词
|
||||
//
|
||||
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
|
||||
func isUniqueConstraintViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return true
|
||||
// 优先检测 PostgreSQL 特定错误码(最精确)。
|
||||
// 错误码 23505 对应 unique_violation。
|
||||
// 参考:https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||
var pgErr *pq.Error
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505"
|
||||
}
|
||||
|
||||
// 回退到错误消息检测(兼容其他场景)。
|
||||
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "duplicate key") ||
|
||||
strings.Contains(msg, "unique constraint") ||
|
||||
|
||||
@@ -3,17 +3,22 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
|
||||
func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if u.Email == "" {
|
||||
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
@@ -26,18 +31,48 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
|
||||
if u.Concurrency == 0 {
|
||||
u.Concurrency = 5
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = time.Now()
|
||||
|
||||
create := client.User.Create().
|
||||
SetEmail(u.Email).
|
||||
SetPasswordHash(u.PasswordHash).
|
||||
SetRole(u.Role).
|
||||
SetStatus(u.Status).
|
||||
SetBalance(u.Balance).
|
||||
SetConcurrency(u.Concurrency).
|
||||
SetUsername(u.Username).
|
||||
SetWechat(u.Wechat).
|
||||
SetNotes(u.Notes)
|
||||
if !u.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(u.CreatedAt)
|
||||
}
|
||||
if u.UpdatedAt.IsZero() {
|
||||
u.UpdatedAt = u.CreatedAt
|
||||
if !u.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(u.UpdatedAt)
|
||||
}
|
||||
require.NoError(t, db.Create(u).Error, "create user")
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
u.ID = created.ID
|
||||
u.CreatedAt = created.CreatedAt
|
||||
u.UpdatedAt = created.UpdatedAt
|
||||
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, groupID := range u.AllowedGroups {
|
||||
_, err := client.UserAllowedGroup.Create().
|
||||
SetUserID(u.ID).
|
||||
SetGroupID(groupID).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user_allowed_groups row")
|
||||
}
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
|
||||
func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if g.Platform == "" {
|
||||
g.Platform = service.PlatformAnthropic
|
||||
}
|
||||
@@ -47,18 +82,46 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
|
||||
if g.SubscriptionType == "" {
|
||||
g.SubscriptionType = service.SubscriptionTypeStandard
|
||||
}
|
||||
if g.CreatedAt.IsZero() {
|
||||
g.CreatedAt = time.Now()
|
||||
|
||||
create := client.Group.Create().
|
||||
SetName(g.Name).
|
||||
SetPlatform(g.Platform).
|
||||
SetStatus(g.Status).
|
||||
SetSubscriptionType(g.SubscriptionType).
|
||||
SetRateMultiplier(g.RateMultiplier).
|
||||
SetIsExclusive(g.IsExclusive)
|
||||
if g.Description != "" {
|
||||
create.SetDescription(g.Description)
|
||||
}
|
||||
if g.UpdatedAt.IsZero() {
|
||||
g.UpdatedAt = g.CreatedAt
|
||||
if g.DailyLimitUSD != nil {
|
||||
create.SetDailyLimitUsd(*g.DailyLimitUSD)
|
||||
}
|
||||
require.NoError(t, db.Create(g).Error, "create group")
|
||||
if g.WeeklyLimitUSD != nil {
|
||||
create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
|
||||
}
|
||||
if g.MonthlyLimitUSD != nil {
|
||||
create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
|
||||
}
|
||||
if !g.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(g.CreatedAt)
|
||||
}
|
||||
if !g.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(g.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create group")
|
||||
|
||||
g.ID = created.ID
|
||||
g.CreatedAt = created.CreatedAt
|
||||
g.UpdatedAt = created.UpdatedAt
|
||||
return g
|
||||
}
|
||||
|
||||
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
|
||||
func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if p.Protocol == "" {
|
||||
p.Protocol = "http"
|
||||
}
|
||||
@@ -71,18 +134,39 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
|
||||
if p.Status == "" {
|
||||
p.Status = service.StatusActive
|
||||
}
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = time.Now()
|
||||
|
||||
create := client.Proxy.Create().
|
||||
SetName(p.Name).
|
||||
SetProtocol(p.Protocol).
|
||||
SetHost(p.Host).
|
||||
SetPort(p.Port).
|
||||
SetStatus(p.Status)
|
||||
if p.Username != "" {
|
||||
create.SetUsername(p.Username)
|
||||
}
|
||||
if p.UpdatedAt.IsZero() {
|
||||
p.UpdatedAt = p.CreatedAt
|
||||
if p.Password != "" {
|
||||
create.SetPassword(p.Password)
|
||||
}
|
||||
require.NoError(t, db.Create(p).Error, "create proxy")
|
||||
if !p.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(p.CreatedAt)
|
||||
}
|
||||
if !p.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(p.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create proxy")
|
||||
|
||||
p.ID = created.ID
|
||||
p.CreatedAt = created.CreatedAt
|
||||
p.UpdatedAt = created.UpdatedAt
|
||||
return p
|
||||
}
|
||||
|
||||
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
|
||||
func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if a.Platform == "" {
|
||||
a.Platform = service.PlatformAnthropic
|
||||
}
|
||||
@@ -92,57 +176,158 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel
|
||||
if a.Status == "" {
|
||||
a.Status = service.StatusActive
|
||||
}
|
||||
if a.Concurrency == 0 {
|
||||
a.Concurrency = 3
|
||||
}
|
||||
if a.Priority == 0 {
|
||||
a.Priority = 50
|
||||
}
|
||||
if !a.Schedulable {
|
||||
a.Schedulable = true
|
||||
}
|
||||
if a.Credentials == nil {
|
||||
a.Credentials = datatypes.JSONMap{}
|
||||
a.Credentials = map[string]any{}
|
||||
}
|
||||
if a.Extra == nil {
|
||||
a.Extra = datatypes.JSONMap{}
|
||||
a.Extra = map[string]any{}
|
||||
}
|
||||
if a.CreatedAt.IsZero() {
|
||||
a.CreatedAt = time.Now()
|
||||
|
||||
create := client.Account.Create().
|
||||
SetName(a.Name).
|
||||
SetPlatform(a.Platform).
|
||||
SetType(a.Type).
|
||||
SetCredentials(a.Credentials).
|
||||
SetExtra(a.Extra).
|
||||
SetConcurrency(a.Concurrency).
|
||||
SetPriority(a.Priority).
|
||||
SetStatus(a.Status).
|
||||
SetSchedulable(a.Schedulable).
|
||||
SetErrorMessage(a.ErrorMessage)
|
||||
|
||||
if a.ProxyID != nil {
|
||||
create.SetProxyID(*a.ProxyID)
|
||||
}
|
||||
if a.UpdatedAt.IsZero() {
|
||||
a.UpdatedAt = a.CreatedAt
|
||||
if a.LastUsedAt != nil {
|
||||
create.SetLastUsedAt(*a.LastUsedAt)
|
||||
}
|
||||
require.NoError(t, db.Create(a).Error, "create account")
|
||||
if a.RateLimitedAt != nil {
|
||||
create.SetRateLimitedAt(*a.RateLimitedAt)
|
||||
}
|
||||
if a.RateLimitResetAt != nil {
|
||||
create.SetRateLimitResetAt(*a.RateLimitResetAt)
|
||||
}
|
||||
if a.OverloadUntil != nil {
|
||||
create.SetOverloadUntil(*a.OverloadUntil)
|
||||
}
|
||||
if a.SessionWindowStart != nil {
|
||||
create.SetSessionWindowStart(*a.SessionWindowStart)
|
||||
}
|
||||
if a.SessionWindowEnd != nil {
|
||||
create.SetSessionWindowEnd(*a.SessionWindowEnd)
|
||||
}
|
||||
if a.SessionWindowStatus != "" {
|
||||
create.SetSessionWindowStatus(a.SessionWindowStatus)
|
||||
}
|
||||
if !a.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(a.CreatedAt)
|
||||
}
|
||||
if !a.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(a.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create account")
|
||||
|
||||
a.ID = created.ID
|
||||
a.CreatedAt = created.CreatedAt
|
||||
a.UpdatedAt = created.UpdatedAt
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if k.Status == "" {
|
||||
k.Status = service.StatusActive
|
||||
}
|
||||
if k.CreatedAt.IsZero() {
|
||||
k.CreatedAt = time.Now()
|
||||
if k.Key == "" {
|
||||
k.Key = "sk-" + time.Now().Format("150405.000000")
|
||||
}
|
||||
if k.UpdatedAt.IsZero() {
|
||||
k.UpdatedAt = k.CreatedAt
|
||||
if k.Name == "" {
|
||||
k.Name = "default"
|
||||
}
|
||||
require.NoError(t, db.Create(k).Error, "create api key")
|
||||
|
||||
create := client.ApiKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
if !k.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(k.CreatedAt)
|
||||
}
|
||||
if !k.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(k.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create api key")
|
||||
|
||||
k.ID = created.ID
|
||||
k.CreatedAt = created.CreatedAt
|
||||
k.UpdatedAt = created.UpdatedAt
|
||||
return k
|
||||
}
|
||||
|
||||
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
|
||||
func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if c.Status == "" {
|
||||
c.Status = service.StatusUnused
|
||||
}
|
||||
if c.Type == "" {
|
||||
c.Type = service.RedeemTypeBalance
|
||||
}
|
||||
if c.CreatedAt.IsZero() {
|
||||
c.CreatedAt = time.Now()
|
||||
if c.Code == "" {
|
||||
c.Code = "rc-" + time.Now().Format("150405.000000")
|
||||
}
|
||||
require.NoError(t, db.Create(c).Error, "create redeem code")
|
||||
|
||||
create := client.RedeemCode.Create().
|
||||
SetCode(c.Code).
|
||||
SetType(c.Type).
|
||||
SetValue(c.Value).
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays)
|
||||
if c.UsedBy != nil {
|
||||
create.SetUsedBy(*c.UsedBy)
|
||||
}
|
||||
if c.UsedAt != nil {
|
||||
create.SetUsedAt(*c.UsedAt)
|
||||
}
|
||||
if c.GroupID != nil {
|
||||
create.SetGroupID(*c.GroupID)
|
||||
}
|
||||
if !c.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(c.CreatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create redeem code")
|
||||
|
||||
c.ID = created.ID
|
||||
c.CreatedAt = created.CreatedAt
|
||||
return c
|
||||
}
|
||||
|
||||
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
|
||||
func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if s.Status == "" {
|
||||
s.Status = service.SubscriptionStatusActive
|
||||
}
|
||||
@@ -162,16 +347,47 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel)
|
||||
if s.UpdatedAt.IsZero() {
|
||||
s.UpdatedAt = now
|
||||
}
|
||||
require.NoError(t, db.Create(s).Error, "create user subscription")
|
||||
|
||||
create := client.UserSubscription.Create().
|
||||
SetUserID(s.UserID).
|
||||
SetGroupID(s.GroupID).
|
||||
SetStartsAt(s.StartsAt).
|
||||
SetExpiresAt(s.ExpiresAt).
|
||||
SetStatus(s.Status).
|
||||
SetAssignedAt(s.AssignedAt).
|
||||
SetNotes(s.Notes).
|
||||
SetDailyUsageUsd(s.DailyUsageUSD).
|
||||
SetWeeklyUsageUsd(s.WeeklyUsageUSD).
|
||||
SetMonthlyUsageUsd(s.MonthlyUsageUSD)
|
||||
|
||||
if s.AssignedBy != nil {
|
||||
create.SetAssignedBy(*s.AssignedBy)
|
||||
}
|
||||
if !s.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(s.CreatedAt)
|
||||
}
|
||||
if !s.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(s.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create user subscription")
|
||||
|
||||
s.ID = created.ID
|
||||
s.CreatedAt = created.CreatedAt
|
||||
s.UpdatedAt = created.UpdatedAt
|
||||
return s
|
||||
}
|
||||
|
||||
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
|
||||
func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
|
||||
t.Helper()
|
||||
require.NoError(t, db.Create(&accountGroupModel{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
CreatedAt: time.Now(),
|
||||
}).Error, "create account_group")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.AccountGroup.Create().
|
||||
SetAccountID(accountID).
|
||||
SetGroupID(groupID).
|
||||
SetPriority(priority).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create account_group")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,280 +2,370 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"database/sql"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type sqlBeginner interface {
|
||||
sqlExecutor
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type groupRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
|
||||
return &groupRepository{db: db}
|
||||
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
||||
return newGroupRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
|
||||
m := groupModelFromService(group)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &groupRepository{client: client, sql: sqlq, begin: beginner}
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
||||
builder := r.client.Group.Create().
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
SetRateMultiplier(groupIn.RateMultiplier).
|
||||
SetIsExclusive(groupIn.IsExclusive).
|
||||
SetStatus(groupIn.Status).
|
||||
SetSubscriptionType(groupIn.SubscriptionType).
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyGroupModelToService(group, m)
|
||||
groupIn.ID = created.ID
|
||||
groupIn.CreatedAt = created.CreatedAt
|
||||
groupIn.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
var m groupModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
group := groupModelToService(&m)
|
||||
count, _ := r.GetAccountCount(ctx, group.ID)
|
||||
group.AccountCount = count
|
||||
return group, nil
|
||||
|
||||
out := groupEntityToService(m)
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
|
||||
m := groupModelFromService(group)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyGroupModelToService(group, m)
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
SetRateMultiplier(groupIn.RateMultiplier).
|
||||
SetIsExclusive(groupIn.IsExclusive).
|
||||
SetStatus(groupIn.Status).
|
||||
SetSubscriptionType(groupIn.SubscriptionType).
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
}
|
||||
return err
|
||||
groupIn.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
|
||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
var groups []groupModel
|
||||
var total int64
|
||||
q := r.client.Group.Query()
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&groupModel{})
|
||||
|
||||
// Apply filters
|
||||
if platform != "" {
|
||||
db = db.Where("platform = ?", platform)
|
||||
q = q.Where(group.PlatformEQ(platform))
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
q = q.Where(group.StatusEQ(status))
|
||||
}
|
||||
if isExclusive != nil {
|
||||
db = db.Where("is_exclusive = ?", *isExclusive)
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id ASC").Find(&groups).Error; err != nil {
|
||||
groups, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, paginationResultFromTotal(total, params), nil
|
||||
return outGroups, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
var groups []groupModel
|
||||
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
var groups []groupModel
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
|
||||
return count > 0, err
|
||||
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
|
||||
return result.RowsAffected, result.Error
|
||||
res, err := r.sql.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
group, err := r.GetByID(ctx, id)
|
||||
g, err := r.client.Group.Query().Where(group.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
groupSvc := groupEntityToService(g)
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
var txClientClose func() error
|
||||
|
||||
if r.begin != nil {
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
txClientClose = txClient.Close
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
}
|
||||
if txClientClose != nil {
|
||||
defer func() { _ = txClientClose() }()
|
||||
}
|
||||
|
||||
// Lock the group row to avoid concurrent writes while we cascade.
|
||||
var lockedID int64
|
||||
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
|
||||
if errorsIsNoRows(err) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Table("user_subscriptions").
|
||||
Where("group_id = ?", id).
|
||||
Pluck("user_id", &affectedUserIDs).Error; err != nil {
|
||||
if groupSvc.IsSubscriptionType() {
|
||||
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||||
_ = rows.Close()
|
||||
return nil, scanErr
|
||||
}
|
||||
affectedUserIDs = append(affectedUserIDs, userID)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 删除订阅类型分组的订阅记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
|
||||
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Exec(
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
|
||||
id, id,
|
||||
).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. 删除分组本身(带锁,避免并发写)
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
if _, err := txClient.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(id)).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Remove the group id from users.allowed_groups array (legacy representation).
|
||||
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := exec.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
|
||||
id,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Delete account_groups join rows.
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. Soft-delete group itself.
|
||||
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
type groupModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null"`
|
||||
Description string `gorm:"type:text"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null"`
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
|
||||
IsExclusive bool `gorm:"default:false;not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (map[int64]int64, error) {
|
||||
counts := make(map[int64]int64, len(groupIDs))
|
||||
if len(groupIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null"`
|
||||
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var count int64
|
||||
if err := rows.Scan(&groupID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[groupID] = count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (groupModel) TableName() string { return "groups" }
|
||||
|
||||
func groupModelToService(m *groupModel) *service.Group {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Description: m.Description,
|
||||
Platform: m.Platform,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
IsExclusive: m.IsExclusive,
|
||||
Status: m.Status,
|
||||
SubscriptionType: m.SubscriptionType,
|
||||
DailyLimitUSD: m.DailyLimitUSD,
|
||||
WeeklyLimitUSD: m.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: m.MonthlyLimitUSD,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
|
||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
||||
return dbent.NewClient(dbent.Driver(drv))
|
||||
}
|
||||
|
||||
func groupModelFromService(sg *service.Group) *groupModel {
|
||||
if sg == nil {
|
||||
return nil
|
||||
}
|
||||
return &groupModel{
|
||||
ID: sg.ID,
|
||||
Name: sg.Name,
|
||||
Description: sg.Description,
|
||||
Platform: sg.Platform,
|
||||
RateMultiplier: sg.RateMultiplier,
|
||||
IsExclusive: sg.IsExclusive,
|
||||
Status: sg.Status,
|
||||
SubscriptionType: sg.SubscriptionType,
|
||||
DailyLimitUSD: sg.DailyLimitUSD,
|
||||
WeeklyLimitUSD: sg.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: sg.MonthlyLimitUSD,
|
||||
CreatedAt: sg.CreatedAt,
|
||||
UpdatedAt: sg.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyGroupModelToService(group *service.Group, m *groupModel) {
|
||||
if group == nil || m == nil {
|
||||
return
|
||||
}
|
||||
group.ID = m.ID
|
||||
group.CreatedAt = m.CreatedAt
|
||||
group.UpdatedAt = m.UpdatedAt
|
||||
func errorsIsNoRows(err error) bool {
|
||||
return err == sql.ErrNoRows
|
||||
}
|
||||
|
||||
@@ -4,25 +4,26 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
tx *sql.Tx
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewGroupRepository(s.db).(*groupRepository)
|
||||
entClient, tx := testEntSQLTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newGroupRepositoryWithSQL(entClient, tx)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
@@ -33,9 +34,12 @@ func TestGroupRepoSuite(t *testing.T) {
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &service.Group{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
@@ -50,10 +54,19 @@ func (s *GroupRepoSuite) TestCreate() {
|
||||
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
@@ -65,20 +78,43 @@ func (s *GroupRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
|
||||
group := &service.Group{
|
||||
Name: "to-delete",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -87,8 +123,22 @@ func (s *GroupRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
s.Require().NoError(err)
|
||||
@@ -97,8 +147,22 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
||||
s.Require().NoError(err)
|
||||
@@ -107,8 +171,22 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: true,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
@@ -118,21 +196,35 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
IsExclusive: true,
|
||||
})
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g2 := &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: true,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
|
||||
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
|
||||
var accountID int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&accountID))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
||||
@@ -146,8 +238,22 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "active1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "inactive1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -156,9 +262,30 @@ func (s *GroupRepoSuite) TestListActive() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g3",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
@@ -169,7 +296,14 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "existing-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
@@ -183,11 +317,33 @@ func (s *GroupRepoSuite) TestExistsByName() {
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
group := &service.Group{
|
||||
Name: "g-count",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
var a1 int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"a1", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&a1))
|
||||
var a2 int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"a2", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&a2))
|
||||
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
@@ -195,7 +351,15 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
|
||||
group := &service.Group{
|
||||
Name: "g-empty",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
@@ -205,9 +369,23 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
|
||||
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
|
||||
g := &service.Group{
|
||||
Name: "g-del",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
var accountID int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&accountID))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||
@@ -219,13 +397,34 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
|
||||
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
|
||||
g := &service.Group{
|
||||
Name: "g-multi",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
|
||||
insertAccount := func(name string) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(s.tx.QueryRowContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
name, service.PlatformAnthropic, service.AccountTypeOAuth,
|
||||
).Scan(&id))
|
||||
return id
|
||||
}
|
||||
a1 := insertAccount("a1")
|
||||
a2 := insertAccount("a2")
|
||||
a3 := insertAccount("a3")
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -15,16 +15,19 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq"
|
||||
redisclient "github.com/redis/go-redis/v9"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
gormpostgres "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +36,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
integrationDB *gorm.DB
|
||||
integrationDB *sql.DB
|
||||
integrationRedis *redisclient.Client
|
||||
|
||||
redisNamespaceSeq uint64
|
||||
@@ -88,13 +91,13 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
|
||||
integrationDB, err = openSQLWithRetry(ctx, dsn, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("failed to open gorm db: %v", err)
|
||||
log.Printf("failed to open sql db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := AutoMigrate(integrationDB); err != nil {
|
||||
log.Printf("failed to automigrate db: %v", err)
|
||||
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
log.Printf("failed to apply db migrations: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -121,6 +124,7 @@ func TestMain(m *testing.M) {
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationRedis.Close()
|
||||
_ = integrationDB.Close()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -147,29 +151,21 @@ func dockerImageExists(ctx context.Context, image string) bool {
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
|
||||
func openSQLWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*sql.DB, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
|
||||
if err := pingWithTimeout(ctx, db, 2*time.Second); err != nil {
|
||||
lastErr = err
|
||||
_ = db.Close()
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
@@ -186,17 +182,31 @@ func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) err
|
||||
return db.PingContext(pingCtx)
|
||||
}
|
||||
|
||||
func testTx(t *testing.T) *gorm.DB {
|
||||
func testTx(t *testing.T) *sql.Tx {
|
||||
t.Helper()
|
||||
|
||||
tx := integrationDB.Begin()
|
||||
require.NoError(t, tx.Error, "begin tx")
|
||||
tx, err := integrationDB.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err, "begin tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback().Error
|
||||
_ = tx.Rollback()
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
|
||||
t.Helper()
|
||||
|
||||
tx := testTx(t)
|
||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
||||
client := dbent.NewClient(dbent.Driver(drv))
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
})
|
||||
|
||||
return client, tx
|
||||
}
|
||||
|
||||
func testRedis(t *testing.T) *redisclient.Client {
|
||||
t.Helper()
|
||||
|
||||
@@ -347,18 +357,19 @@ func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
|
||||
assertTTLWithin(s.T(), ttl, min, max)
|
||||
}
|
||||
|
||||
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and db.
|
||||
// IntegrationDBSuite provides a base suite for DB integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and client.
|
||||
type IntegrationDBSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and db for each test method.
|
||||
// SetupTest initializes ctx and client for each test method.
|
||||
func (s *IntegrationDBSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.client, s.tx = testEntSQLTx(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
|
||||
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
|
||||
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
|
||||
|
||||
// schema_migrations should have at least the current migration set.
|
||||
var applied int
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM schema_migrations").Scan(&applied))
|
||||
require.GreaterOrEqual(t, applied, 7, "expected schema_migrations to contain applied migrations")
|
||||
|
||||
// users: columns required by repository queries
|
||||
requireColumn(t, tx, "users", "username", "character varying", 100, false)
|
||||
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
|
||||
requireColumn(t, tx, "users", "notes", "text", 0, false)
|
||||
|
||||
// accounts: schedulable and rate-limit fields
|
||||
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
|
||||
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)
|
||||
|
||||
// api_keys: key length should be 128
|
||||
requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)
|
||||
|
||||
// redeem_codes: subscription fields
|
||||
requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
|
||||
requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)
|
||||
|
||||
// usage_logs: billing_type used by filters/stats
|
||||
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
|
||||
|
||||
// user_allowed_groups table should exist
|
||||
var uagRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
|
||||
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
var row struct {
|
||||
DataType string
|
||||
MaxLen sql.NullInt64
|
||||
Nullable string
|
||||
}
|
||||
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT
|
||||
data_type,
|
||||
character_maximum_length,
|
||||
is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
AND column_name = $2
|
||||
`, table, column).Scan(&row.DataType, &row.MaxLen, &row.Nullable)
|
||||
require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)
|
||||
require.Equal(t, dataType, row.DataType, "data_type mismatch for %s.%s", table, column)
|
||||
|
||||
if maxLen > 0 {
|
||||
require.True(t, row.MaxLen.Valid, "expected maxLen for %s.%s", table, column)
|
||||
require.Equal(t, int64(maxLen), row.MaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
|
||||
}
|
||||
|
||||
if nullable {
|
||||
require.Equal(t, "YES", row.Nullable, "nullable mismatch for %s.%s", table, column)
|
||||
} else {
|
||||
require.Equal(t, "NO", row.Nullable, "nullable mismatch for %s.%s", table, column)
|
||||
}
|
||||
}
|
||||
@@ -2,52 +2,97 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"database/sql"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type sqlQuerier interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type proxyRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
sql sqlQuerier
|
||||
}
|
||||
|
||||
func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
|
||||
return &proxyRepository{db: db}
|
||||
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
|
||||
return newProxyRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||
m := proxyModelFromService(proxy)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
|
||||
return &proxyRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.Create().
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyModelToService(proxy, m)
|
||||
applyProxyEntityToService(proxyIn, created)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
var m proxyModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
m, err := r.client.Proxy.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrProxyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return proxyModelToService(&m), nil
|
||||
return proxyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
|
||||
m := proxyModelFromService(proxy)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
} else {
|
||||
builder.ClearUsername()
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
} else {
|
||||
builder.ClearPassword()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyModelToService(proxy, m)
|
||||
applyProxyEntityToService(proxyIn, updated)
|
||||
return nil
|
||||
}
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrProxyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
|
||||
_, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
@@ -56,104 +101,111 @@ func (r *proxyRepository) List(ctx context.Context, params pagination.Pagination
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []proxyModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&proxyModel{})
|
||||
|
||||
// Apply filters
|
||||
q := r.client.Proxy.Query()
|
||||
if protocol != "" {
|
||||
db = db.Where("protocol = ?", protocol)
|
||||
q = q.Where(proxy.ProtocolEQ(protocol))
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
q = q.Where(proxy.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("name ILIKE ?", searchPattern)
|
||||
q = q.Where(proxy.NameContainsFold(search))
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&proxies).Error; err != nil {
|
||||
proxies, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
|
||||
return outProxies, paginationResultFromTotal(total, params), nil
|
||||
return outProxies, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
var proxies []proxyModel
|
||||
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return outProxies, nil
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&proxyModel{}).
|
||||
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
q := r.client.Proxy.Query().
|
||||
Where(proxy.HostEQ(host), proxy.PortEQ(port))
|
||||
|
||||
if username == "" {
|
||||
q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.UsernameEQ(username))
|
||||
}
|
||||
return count > 0, nil
|
||||
if password == "" {
|
||||
q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.PasswordEQ(password))
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Table("accounts").
|
||||
Where("proxy_id = ?", proxyID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
if err := row.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
type result struct {
|
||||
ProxyID int64 `gorm:"column:proxy_id"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := r.db.WithContext(ctx).
|
||||
Table("accounts").
|
||||
Select("proxy_id, COUNT(*) as count").
|
||||
Where("proxy_id IS NOT NULL").
|
||||
Group("proxy_id").
|
||||
Scan(&results).Error
|
||||
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.ProxyID] = r.Count
|
||||
for rows.Next() {
|
||||
var proxyID, count int64
|
||||
if err := rows.Scan(&proxyID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[proxyID] = count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
var proxies []proxyModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", service.StatusActive).
|
||||
Order("created_at DESC").
|
||||
Find(&proxies).Error
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Desc(proxy.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -167,76 +219,47 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]ser
|
||||
// Build result with account counts
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxy := proxyModelToService(&proxies[i])
|
||||
if proxy == nil {
|
||||
proxyOut := proxyEntityToService(proxies[i])
|
||||
if proxyOut == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.ProxyWithAccountCount{
|
||||
Proxy: *proxy,
|
||||
AccountCount: counts[proxy.ID],
|
||||
Proxy: *proxyOut,
|
||||
AccountCount: counts[proxyOut.ID],
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type proxyModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
Protocol string `gorm:"size:20;not null"`
|
||||
Host string `gorm:"size:255;not null"`
|
||||
Port int `gorm:"not null"`
|
||||
Username string `gorm:"size:100"`
|
||||
Password string `gorm:"size:100"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (proxyModel) TableName() string { return "proxies" }
|
||||
|
||||
func proxyModelToService(m *proxyModel) *service.Proxy {
|
||||
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Proxy{
|
||||
out := &service.Proxy{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Protocol: m.Protocol,
|
||||
Host: m.Host,
|
||||
Port: m.Port,
|
||||
Username: m.Username,
|
||||
Password: m.Password,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
if m.Username != nil {
|
||||
out.Username = *m.Username
|
||||
}
|
||||
if m.Password != nil {
|
||||
out.Password = *m.Password
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func proxyModelFromService(p *service.Proxy) *proxyModel {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &proxyModel{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
|
||||
if proxy == nil || m == nil {
|
||||
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
proxy.ID = m.ID
|
||||
proxy.CreatedAt = m.CreatedAt
|
||||
proxy.UpdatedAt = m.UpdatedAt
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -4,26 +4,27 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *proxyRepository
|
||||
ctx context.Context
|
||||
sqlTx *sql.Tx
|
||||
repo *proxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewProxyRepository(s.db).(*proxyRepository)
|
||||
entClient, sqlTx := testEntSQLTx(s.T())
|
||||
s.sqlTx = sqlTx
|
||||
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
@@ -56,7 +57,14 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestUpdate() {
|
||||
proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
|
||||
proxy := &service.Proxy{
|
||||
Name: "original",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, proxy))
|
||||
|
||||
proxy.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, proxy)
|
||||
@@ -68,7 +76,14 @@ func (s *ProxyRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestDelete() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
|
||||
proxy := &service.Proxy{
|
||||
Name: "to-delete",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, proxy))
|
||||
|
||||
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -80,8 +95,8 @@ func (s *ProxyRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestList() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -90,8 +105,8 @@ func (s *ProxyRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "socks5", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||
s.Require().NoError(err)
|
||||
@@ -100,8 +115,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
|
||||
s.Require().NoError(err)
|
||||
@@ -110,8 +125,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "production-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "dev-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||
s.Require().NoError(err)
|
||||
@@ -122,8 +137,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
// --- ListActive ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActive() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "active1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "inactive1", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
|
||||
|
||||
proxies, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -134,13 +149,14 @@ func (s *ProxyRepoSuite) TestListActive() {
|
||||
// --- ExistsByHostPortAuth ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
s.mustCreateProxy(&service.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
|
||||
@@ -153,13 +169,14 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
s.mustCreateProxy(&service.Proxy{
|
||||
Name: "p-noauth",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
|
||||
@@ -170,10 +187,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
// --- CountAccountsByProxyID ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
|
||||
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-count", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustInsertAccount("a1", &proxy.ID)
|
||||
s.mustInsertAccount("a2", &proxy.ID)
|
||||
s.mustInsertAccount("a3", nil) // no proxy
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
@@ -181,7 +198,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
|
||||
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-zero", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err)
|
||||
@@ -191,12 +208,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
// --- GetAccountCountsForProxies ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
|
||||
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
@@ -215,24 +232,13 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p1",
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: base.Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p2",
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: base,
|
||||
})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p3-inactive",
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
p1 := s.mustCreateProxyWithTimes("p1", service.StatusActive, base.Add(-1*time.Hour))
|
||||
p2 := s.mustCreateProxyWithTimes("p2", service.StatusActive, base)
|
||||
s.mustCreateProxyWithTimes("p3-inactive", service.StatusDisabled, base.Add(1*time.Hour))
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
@@ -248,34 +254,16 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
CreatedAt: time.Now().Add(-1 * time.Hour),
|
||||
UpdatedAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p2",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "1.2.3.4", Port: 8080, Username: "u", Password: "p", Status: service.StatusActive})
|
||||
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "5.6.7.8", Port: 8081, Username: "", Password: "", Status: service.StatusActive})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists, "expected proxy to exist")
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
@@ -300,3 +288,42 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustCreateProxy(p *service.Proxy) *service.Proxy {
|
||||
s.T().Helper()
|
||||
s.Require().NoError(s.repo.Create(s.ctx, p), "create proxy")
|
||||
return p
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt time.Time) *service.Proxy {
|
||||
s.T().Helper()
|
||||
|
||||
// Use the repository create for standard fields, then update timestamps via raw SQL to keep deterministic ordering.
|
||||
p := s.mustCreateProxy(&service.Proxy{
|
||||
Name: name,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: status,
|
||||
})
|
||||
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
||||
s.Require().NoError(err, "update proxy timestamps")
|
||||
return p
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
|
||||
s.T().Helper()
|
||||
var pid any
|
||||
if proxyID != nil {
|
||||
pid = *proxyID
|
||||
}
|
||||
_, err := s.sqlTx.ExecContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
|
||||
name,
|
||||
service.PlatformAnthropic,
|
||||
service.AccountTypeOAuth,
|
||||
pid,
|
||||
)
|
||||
s.Require().NoError(err, "insert account")
|
||||
}
|
||||
|
||||
@@ -4,25 +4,35 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type redeemCodeRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
|
||||
return &redeemCodeRepository{db: db}
|
||||
func NewRedeemCodeRepository(client *dbent.Client) service.RedeemCodeRepository {
|
||||
return &redeemCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
m := redeemCodeModelFromService(code)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
created, err := r.client.RedeemCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays).
|
||||
SetNillableUsedBy(code.UsedBy).
|
||||
SetNillableUsedAt(code.UsedAt).
|
||||
SetNillableGroupID(code.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
applyRedeemCodeModelToService(code, m)
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -31,36 +41,55 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.
|
||||
if len(codes) == 0 {
|
||||
return nil
|
||||
}
|
||||
models := make([]redeemCodeModel, 0, len(codes))
|
||||
|
||||
builders := make([]*dbent.RedeemCodeCreate, 0, len(codes))
|
||||
for i := range codes {
|
||||
m := redeemCodeModelFromService(&codes[i])
|
||||
if m != nil {
|
||||
models = append(models, *m)
|
||||
}
|
||||
c := &codes[i]
|
||||
b := r.client.RedeemCode.Create().
|
||||
SetCode(c.Code).
|
||||
SetType(c.Type).
|
||||
SetValue(c.Value).
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays).
|
||||
SetNillableUsedBy(c.UsedBy).
|
||||
SetNillableUsedAt(c.UsedAt).
|
||||
SetNillableGroupID(c.GroupID)
|
||||
builders = append(builders, b)
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&models).Error
|
||||
|
||||
return r.client.RedeemCode.CreateBulk(builders...).Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
var m redeemCodeModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeModelToService(&m), nil
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
var m redeemCodeModel
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.CodeEQ(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeModelToService(&m), nil
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
|
||||
_, err := r.client.RedeemCode.Delete().Where(redeemcode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
@@ -68,61 +97,88 @@ func (r *redeemCodeRepository) List(ctx context.Context, params pagination.Pagin
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []redeemCodeModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
|
||||
q := r.client.RedeemCode.Query()
|
||||
|
||||
if codeType != "" {
|
||||
db = db.Where("type = ?", codeType)
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where("code ILIKE ?", searchPattern)
|
||||
q = q.Where(redeemcode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := db.Preload("User").Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&codes).Error; err != nil {
|
||||
codes, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := make([]service.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
|
||||
}
|
||||
outCodes := redeemCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(total, params), nil
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
m := redeemCodeModelFromService(code)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyRedeemCodeModelToService(code, m)
|
||||
up := r.client.RedeemCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays)
|
||||
|
||||
if code.UsedBy != nil {
|
||||
up.SetUsedBy(*code.UsedBy)
|
||||
} else {
|
||||
up.ClearUsedBy()
|
||||
}
|
||||
return err
|
||||
if code.UsedAt != nil {
|
||||
up.SetUsedAt(*code.UsedAt)
|
||||
} else {
|
||||
up.ClearUsedAt()
|
||||
}
|
||||
if code.GroupID != nil {
|
||||
up.SetGroupID(*code.GroupID)
|
||||
} else {
|
||||
up.ClearGroupID()
|
||||
}
|
||||
|
||||
updated, err := up.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
code.CreatedAt = updated.CreatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
|
||||
Where("id = ? AND status = ?", id, service.StatusUnused).
|
||||
Updates(map[string]any{
|
||||
"status": service.StatusUsed,
|
||||
"used_by": userID,
|
||||
"used_at": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
affected, err := r.client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
|
||||
if affected == 0 {
|
||||
return service.ErrRedeemCodeUsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -132,49 +188,24 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
|
||||
limit = 10
|
||||
}
|
||||
|
||||
var codes []redeemCodeModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("used_by = ?", userID).
|
||||
Order("used_at DESC").
|
||||
codes, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID)).
|
||||
WithGroup().
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
Limit(limit).
|
||||
Find(&codes).Error
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outCodes := make([]service.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
|
||||
}
|
||||
return outCodes, nil
|
||||
return redeemCodeEntitiesToService(codes), nil
|
||||
}
|
||||
|
||||
type redeemCodeModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null"`
|
||||
Type string `gorm:"size:20;default:balance;not null"`
|
||||
Value float64 `gorm:"type:decimal(20,8);not null"`
|
||||
Status string `gorm:"size:20;default:unused;not null"`
|
||||
UsedBy *int64 `gorm:"index"`
|
||||
UsedAt *time.Time
|
||||
Notes string `gorm:"type:text"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
|
||||
GroupID *int64 `gorm:"index"`
|
||||
ValidityDays int `gorm:"default:30"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UsedBy"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (redeemCodeModel) TableName() string { return "redeem_codes" }
|
||||
|
||||
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
|
||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.RedeemCode{
|
||||
out := &service.RedeemCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
Type: m.Type,
|
||||
@@ -182,38 +213,26 @@ func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
|
||||
Status: m.Status,
|
||||
UsedBy: m.UsedBy,
|
||||
UsedAt: m.UsedAt,
|
||||
Notes: m.Notes,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ValidityDays: m.ValidityDays,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &redeemCodeModel{
|
||||
ID: r.ID,
|
||||
Code: r.Code,
|
||||
Type: r.Type,
|
||||
Value: r.Value,
|
||||
Status: r.Status,
|
||||
UsedBy: r.UsedBy,
|
||||
UsedAt: r.UsedAt,
|
||||
Notes: r.Notes,
|
||||
CreatedAt: r.CreatedAt,
|
||||
GroupID: r.GroupID,
|
||||
ValidityDays: r.ValidityDays,
|
||||
func redeemCodeEntitiesToService(models []*dbent.RedeemCode) []service.RedeemCode {
|
||||
out := make([]service.RedeemCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := redeemCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
|
||||
if code == nil || m == nil {
|
||||
return
|
||||
}
|
||||
code.ID = m.ID
|
||||
code.CreatedAt = m.CreatedAt
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -7,29 +7,47 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RedeemCodeRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *redeemCodeRepository
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *redeemCodeRepository
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.client = entClient
|
||||
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository)
|
||||
}
|
||||
|
||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(RedeemCodeRepoSuite))
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) createUser(email string) *dbent.User {
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return g
|
||||
}
|
||||
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
@@ -70,10 +88,19 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("GET-BY-CODE").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "seed redeem code")
|
||||
|
||||
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
@@ -83,25 +110,35 @@ func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
|
||||
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
|
||||
s.Require().Error(err, "expected error for non-existent code")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
|
||||
created, err := s.client.RedeemCode.Create().
|
||||
SetCode("TO-DELETE").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err := s.repo.Delete(s.ctx, code.ID)
|
||||
err = s.repo.Delete(s.ctx, created.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, code.ID)
|
||||
_, err = s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestList() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-1", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-2", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -110,8 +147,8 @@ func (s *RedeemCodeRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-BAL", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
|
||||
s.Require().NoError(err)
|
||||
@@ -120,8 +157,8 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
|
||||
s.Require().NoError(err)
|
||||
@@ -130,8 +167,8 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
|
||||
s.Require().NoError(err)
|
||||
@@ -140,12 +177,17 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "WITH-GROUP",
|
||||
Type: service.RedeemTypeSubscription,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-preload"))
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("WITH-GROUP").
|
||||
SetType(service.RedeemTypeSubscription).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetGroupID(group.ID).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
@@ -157,7 +199,13 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
// --- Update ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
|
||||
code := &service.RedeemCode{
|
||||
Code: "UPDATE-ME",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
code.Value = 50
|
||||
err := s.repo.Update(s.ctx, code)
|
||||
@@ -171,8 +219,9 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
// --- Use ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "use") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "USE-ME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use")
|
||||
@@ -186,8 +235,9 @@ func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "idem") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use first time")
|
||||
@@ -199,8 +249,9 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "already") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "expected error for already used code")
|
||||
@@ -210,25 +261,34 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "listby") + "@example.com")
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create codes with explicit used_at for ordering
|
||||
c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "USER-1",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c1).Update("used_at", base)
|
||||
usedAt1 := base
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("USER-1").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(usedAt1).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "USER-2",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
|
||||
usedAt2 := base.Add(1 * time.Hour)
|
||||
_, err = s.client.RedeemCode.Create().
|
||||
SetCode("USER-2").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(usedAt2).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
@@ -239,17 +299,21 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "grp") + "@example.com")
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-listby"))
|
||||
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "WITH-GRP",
|
||||
Type: service.RedeemTypeSubscription,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
s.db.Model(c).Update("used_at", time.Now())
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("WITH-GRP").
|
||||
SetType(service.RedeemTypeSubscription).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(time.Now()).
|
||||
SetGroupID(group.ID).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err)
|
||||
@@ -259,14 +323,18 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "DEF-LIM",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c).Update("used_at", time.Now())
|
||||
user := s.createUser(uniqueTestValue(s.T(), "deflimit") + "@example.com")
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("DEF-LIM").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(time.Now()).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// limit <= 0 should default to 10
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
|
||||
@@ -277,12 +345,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
|
||||
user := s.createUser(uniqueTestValue(s.T(), "rc") + "@example.com")
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-rc"))
|
||||
groupID := group.ID
|
||||
|
||||
codes := []service.RedeemCode{
|
||||
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
|
||||
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
|
||||
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, Notes: ""},
|
||||
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, Notes: "", GroupID: &groupID, ValidityDays: 7},
|
||||
}
|
||||
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
|
||||
|
||||
@@ -303,10 +372,16 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
|
||||
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
|
||||
// Use fixed time instead of time.Sleep for deterministic ordering
|
||||
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
// Use fixed time instead of time.Sleep for deterministic ordering.
|
||||
_, err = s.client.RedeemCode.UpdateOneID(codeB.ID).
|
||||
SetUsedAt(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
|
||||
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
|
||||
_, err = s.client.RedeemCode.UpdateOneID(codeA.ID).
|
||||
SetUsedAt(time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
|
||||
@@ -4,27 +4,33 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type settingRepository struct {
|
||||
db *gorm.DB
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
|
||||
return &settingRepository{db: db}
|
||||
func NewSettingRepository(client *ent.Client) service.SettingRepository {
|
||||
return &settingRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
var m settingModel
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
|
||||
m, err := r.client.Setting.Query().Where(setting.KeyEQ(key)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return settingModelToService(&m), nil
|
||||
return &service.Setting{
|
||||
ID: m.ID,
|
||||
Key: m.Key,
|
||||
Value: m.Value,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
@@ -36,21 +42,22 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
|
||||
}
|
||||
|
||||
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
|
||||
m := &settingModel{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(m).Error
|
||||
now := time.Now()
|
||||
return r.client.Setting.
|
||||
Create().
|
||||
SetKey(key).
|
||||
SetValue(value).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
var settings []settingModel
|
||||
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
|
||||
if len(keys) == 0 {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
settings, err := r.client.Setting.Query().Where(setting.KeyIn(keys...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -63,27 +70,24 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
|
||||
}
|
||||
|
||||
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for key, value := range settings {
|
||||
m := &settingModel{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(m).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(settings) == 0 {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
builders := make([]*ent.SettingCreate, 0, len(settings))
|
||||
for key, value := range settings {
|
||||
builders = append(builders, r.client.Setting.Create().SetKey(key).SetValue(value).SetUpdatedAt(now))
|
||||
}
|
||||
return r.client.Setting.
|
||||
CreateBulk(builders...).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
var settings []settingModel
|
||||
err := r.db.WithContext(ctx).Find(&settings).Error
|
||||
settings, err := r.client.Setting.Query().All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -96,26 +100,6 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
|
||||
}
|
||||
|
||||
func (r *settingRepository) Delete(ctx context.Context, key string) error {
|
||||
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
|
||||
}
|
||||
|
||||
type settingModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Key string `gorm:"uniqueIndex;size:100;not null"`
|
||||
Value string `gorm:"type:text;not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (settingModel) TableName() string { return "settings" }
|
||||
|
||||
func settingModelToService(m *settingModel) *service.Setting {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Setting{
|
||||
ID: m.ID,
|
||||
Key: m.Key,
|
||||
Value: m.Value,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
_, err := r.client.Setting.Delete().Where(setting.KeyEQ(key)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,20 +8,18 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SettingRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *settingRepository
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewSettingRepository(s.db).(*settingRepository)
|
||||
entClient, _ := testEntSQLTx(s.T())
|
||||
s.repo = NewSettingRepository(entClient).(*settingRepository)
|
||||
}
|
||||
|
||||
func TestSettingRepoSuite(t *testing.T) {
|
||||
|
||||
110
backend/internal/repository/soft_delete_ent_integration_test.go
Normal file
110
backend/internal/repository/soft_delete_ent_integration_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueSoftDeleteValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *dbent.User {
|
||||
t.Helper()
|
||||
|
||||
u, err := client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create ent user")
|
||||
return u
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
Name: "soft-delete",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key), "create api key")
|
||||
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
_, err := repo.GetByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
|
||||
|
||||
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
require.Error(t, err, "default ent query should not see soft-deleted rows")
|
||||
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
|
||||
|
||||
got, err := client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
|
||||
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
Name: "soft-delete2",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key), "create api key")
|
||||
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "first delete")
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, _ := testEntSQLTx(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
Name: "soft-delete3",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key), "create api key")
|
||||
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
|
||||
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "hard delete")
|
||||
|
||||
_, err = client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,35 +4,39 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UsageLogRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *usageLogRepository
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
client *dbent.Client
|
||||
repo *usageLogRepository
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
|
||||
client, tx := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
s.tx = tx
|
||||
s.repo = newUsageLogRepositoryWithSQL(client, tx)
|
||||
}
|
||||
|
||||
func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
@@ -51,9 +55,9 @@ func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel,
|
||||
// --- Create / GetByID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
@@ -72,9 +76,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -92,9 +96,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
// --- Delete ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -108,9 +112,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
@@ -124,9 +128,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
// --- ListByApiKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
@@ -140,9 +144,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
// --- ListByAccount ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -155,9 +159,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
// --- GetUserStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -175,9 +179,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
// --- ListWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -194,26 +198,26 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
|
||||
userToday := mustCreateUser(s.T(), s.db, &userModel{
|
||||
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
||||
Email: "today@example.com",
|
||||
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
userOld := mustCreateUser(s.T(), s.db, &userModel{
|
||||
userOld := mustCreateUser(s.T(), s.client, &service.User{
|
||||
Email: "old@example.com",
|
||||
CreatedAt: todayStart.Add(-24 * time.Hour),
|
||||
UpdatedAt: todayStart.Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-error", Status: service.StatusError, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &service.UsageLog{
|
||||
@@ -285,7 +289,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||
|
||||
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
|
||||
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
|
||||
s.Require().NoError(err, "getPerformanceStats")
|
||||
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
|
||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||
}
|
||||
@@ -293,9 +298,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
// --- GetUserDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -308,9 +313,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
// --- GetAccountTodayStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -323,11 +328,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
// --- GetBatchUserUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
@@ -348,10 +353,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
@@ -370,9 +375,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
// --- GetGlobalStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -395,9 +400,9 @@ func maxTime(a, b time.Time) time.Time {
|
||||
// --- ListByUserAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -414,9 +419,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -433,9 +438,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
// --- ListByAccountAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -452,9 +457,9 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
// --- ListByModelAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -508,9 +513,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
// --- GetAccountWindowStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-10 * time.Minute)
|
||||
@@ -528,9 +533,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
// --- GetUserUsageTrendByUserID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -545,9 +550,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -564,9 +569,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
// --- GetUserModelStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -611,9 +616,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
// --- GetUsageTrendWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -639,9 +644,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -658,9 +663,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
// --- GetModelStatsWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -712,9 +717,9 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
// --- GetAccountUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -758,7 +763,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-emptystats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
startTime := base
|
||||
@@ -774,11 +779,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
// --- GetUserUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
|
||||
@@ -796,10 +801,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
|
||||
@@ -815,9 +820,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
|
||||
@@ -834,9 +839,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
// --- ListWithFilters (additional filter tests) ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -848,9 +853,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -867,9 +872,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
|
||||
@@ -2,252 +2,412 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) service.UserRepository {
|
||||
return &userRepository{db: db}
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
|
||||
m := userModelFromService(user)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyUserModelToService(user, m)
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
return &userRepository{client: client, sql: sqlq, begin: beginner}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
var txClientClose func() error
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
txClientClose = txClient.Close
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
}
|
||||
if txClientClose != nil {
|
||||
defer func() { _ = txClientClose() }()
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetWechat(userIn.Wechat).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
applyUserEntityToService(userIn, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return userModelToService(&m), nil
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{id})
|
||||
if err == nil {
|
||||
if v, ok := groups[id]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
|
||||
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return userModelToService(&m), nil
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err == nil {
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
|
||||
m := userModelFromService(user)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyUserModelToService(user, m)
|
||||
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
var txClientClose func() error
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
txClientClose = txClient.Close
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
}
|
||||
if txClientClose != nil {
|
||||
defer func() { _ = txClientClose() }()
|
||||
}
|
||||
|
||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
SetWechat(userIn.Wechat).
|
||||
SetNotes(userIn.Notes).
|
||||
SetPasswordHash(userIn.PasswordHash).
|
||||
SetRole(userIn.Role).
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
userIn.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
|
||||
_, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
|
||||
var users []userModel
|
||||
var total int64
|
||||
q := r.client.User.Query()
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&userModel{})
|
||||
|
||||
// Apply filters
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
q = q.Where(dbuser.StatusEQ(status))
|
||||
}
|
||||
if role != "" {
|
||||
db = db.Where("role = ?", role)
|
||||
q = q.Where(dbuser.RoleEQ(role))
|
||||
}
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
db = db.Where(
|
||||
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
|
||||
searchPattern, searchPattern, searchPattern,
|
||||
q = q.Where(
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(search),
|
||||
dbuser.UsernameContainsFold(search),
|
||||
dbuser.WechatContainsFold(search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Query users with pagination (reuse the same db with filters applied)
|
||||
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
||||
users, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Batch load subscriptions for all users (avoid N+1)
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]int64, len(users))
|
||||
userMap := make(map[int64]*service.User, len(users))
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs[i] = users[i].ID
|
||||
u := userModelToService(&users[i])
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Query active subscriptions with groups in one query
|
||||
var subscriptions []userSubscriptionModel
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Associate subscriptions with users
|
||||
for i := range subscriptions {
|
||||
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
||||
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
for i := range users {
|
||||
outUsers = append(outUsers, *userModelToService(&users[i]))
|
||||
if len(users) == 0 {
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(total, params), nil
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
userMap := make(map[int64]*service.User, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
u := userEntityToService(users[i])
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Batch load active subscriptions with groups to avoid N+1.
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDIn(userIDs...),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
).
|
||||
WithGroup().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := range subs {
|
||||
if u, ok := userMap[subs[i].UserID]; ok {
|
||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||
}
|
||||
}
|
||||
|
||||
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
|
||||
if err == nil {
|
||||
for id, u := range userMap {
|
||||
if groups, ok := allowedGroupsByUser[id]; ok {
|
||||
u.AllowedGroups = groups
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
||||
Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeductBalance 扣减用户余额,仅当余额充足时执行
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
result := r.db.WithContext(ctx).Model(&userModel{}).
|
||||
Where("id = ? AND balance >= ?", id, amount).
|
||||
Update("balance", gorm.Expr("balance - ?", amount))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
n, err := r.client.User.Update().
|
||||
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
|
||||
AddBalance(-amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
if n == 0 {
|
||||
return service.ErrInsufficientBalance
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
||||
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
|
||||
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
|
||||
return count > 0, err
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
|
||||
// 使用 PostgreSQL 的 array_remove 函数
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&userModel{}).
|
||||
Where("? = ANY(allowed_groups)", groupID).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
return result.RowsAffected, result.Error
|
||||
if r.sql == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
joinAffected, err := r.client.UserAllowedGroup.Delete().
|
||||
Where(userallowedgroup.GroupIDEQ(groupID)).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
arrayRes, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
arrayAffected, _ := arrayRes.RowsAffected()
|
||||
|
||||
if int64(joinAffected) > arrayAffected {
|
||||
return int64(joinAffected), nil
|
||||
}
|
||||
return arrayAffected, nil
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
|
||||
Order("id ASC").
|
||||
First(&m).Error
|
||||
m, err := r.client.User.Query().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.StatusEQ(service.StatusActive),
|
||||
).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
First(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return userModelToService(&m), nil
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err == nil {
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type userModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null"`
|
||||
Username string `gorm:"size:100;default:''"`
|
||||
Wechat string `gorm:"size:100;default:''"`
|
||||
Notes string `gorm:"type:text;default:''"`
|
||||
PasswordHash string `gorm:"size:255;not null"`
|
||||
Role string `gorm:"size:20;default:user;not null"`
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
|
||||
Concurrency int `gorm:"default:5;not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
|
||||
out := make(map[int64][]int64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.UserAllowedGroup.Query().
|
||||
Where(userallowedgroup.UserIDIn(userIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
|
||||
}
|
||||
|
||||
for userID := range out {
|
||||
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (userModel) TableName() string { return "users" }
|
||||
|
||||
func userModelToService(m *userModel) *service.User {
|
||||
if m == nil {
|
||||
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
||||
if client == nil || exec == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: m.ID,
|
||||
Email: m.Email,
|
||||
Username: m.Username,
|
||||
Wechat: m.Wechat,
|
||||
Notes: m.Notes,
|
||||
PasswordHash: m.PasswordHash,
|
||||
Role: m.Role,
|
||||
Balance: m.Balance,
|
||||
Concurrency: m.Concurrency,
|
||||
Status: m.Status,
|
||||
AllowedGroups: []int64(m.AllowedGroups),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
legacyGroups := make([]int64, 0, len(unique))
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
legacyGroups = append(legacyGroups, groupID)
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1 compatibility: keep legacy users.allowed_groups array updated for existing raw SQL paths.
|
||||
var legacy any
|
||||
if len(legacyGroups) > 0 {
|
||||
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
|
||||
legacy = pq.Array(legacyGroups)
|
||||
}
|
||||
if _, err := exec.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func userModelFromService(u *service.User) *userModel {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &userModel{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: pq.Int64Array(u.AllowedGroups),
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyUserModelToService(dst *service.User, src *userModel) {
|
||||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,46 +4,103 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *userRepository
|
||||
ctx context.Context
|
||||
tx *sql.Tx
|
||||
client *dbent.Client
|
||||
repo *userRepository
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserRepository(s.db).(*userRepository)
|
||||
entClient, tx := testEntSQLTx(s.T())
|
||||
s.tx = tx
|
||||
s.client = entClient
|
||||
s.repo = newUserRepositoryWithSQL(entClient, tx)
|
||||
}
|
||||
|
||||
func TestUserRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
if u.Email == "" {
|
||||
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = service.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = service.StatusActive
|
||||
}
|
||||
if u.Concurrency == 0 {
|
||||
u.Concurrency = 5
|
||||
}
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
|
||||
s.T().Helper()
|
||||
|
||||
now := time.Now()
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1*time.Hour)).
|
||||
SetExpiresAt(now.Add(24*time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
|
||||
if mutate != nil {
|
||||
mutate(create)
|
||||
}
|
||||
|
||||
sub, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create subscription")
|
||||
return sub
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||
|
||||
func (s *UserRepoSuite) TestCreate() {
|
||||
user := &service.User{
|
||||
user := s.mustCreateUser(&service.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
})
|
||||
|
||||
err := s.repo.Create(s.ctx, user)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(user.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
@@ -57,7 +114,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
|
||||
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
@@ -70,19 +127,20 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
|
||||
|
||||
user.Username = "updated"
|
||||
err := s.repo.Update(s.ctx, user)
|
||||
s.Require().NoError(err, "Update")
|
||||
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
got.Username = "updated"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Username)
|
||||
s.Require().Equal("updated", updated.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -94,8 +152,8 @@ func (s *UserRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
|
||||
s.mustCreateUser(&service.User{Email: "list1@test.com"})
|
||||
s.mustCreateUser(&service.User{Email: "list2@test.com"})
|
||||
|
||||
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -104,8 +162,8 @@ func (s *UserRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
|
||||
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
|
||||
s.Require().NoError(err)
|
||||
@@ -114,8 +172,8 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
|
||||
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
|
||||
s.Require().NoError(err)
|
||||
@@ -124,8 +182,8 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
|
||||
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
|
||||
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
|
||||
s.Require().NoError(err)
|
||||
@@ -134,8 +192,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
|
||||
s.Require().NoError(err)
|
||||
@@ -144,8 +202,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
|
||||
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
|
||||
s.Require().NoError(err)
|
||||
@@ -154,20 +212,17 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
|
||||
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
|
||||
groupActive := s.mustCreateGroup("g-sub-active")
|
||||
groupExpired := s.mustCreateGroup("g-sub-expired")
|
||||
|
||||
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusActive)
|
||||
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
|
||||
})
|
||||
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
|
||||
@@ -175,11 +230,11 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
s.Require().Len(users, 1, "expected 1 user")
|
||||
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
|
||||
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
@@ -187,7 +242,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
target := mustCreateUser(s.T(), s.db, &userModel{
|
||||
target := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
@@ -195,7 +250,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
@@ -211,40 +266,40 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
// --- Balance operations ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
|
||||
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||
s.Require().NoError(err, "UpdateBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(12.5, got.Balance)
|
||||
s.Require().InDelta(12.5, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
|
||||
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||
s.Require().NoError(err, "UpdateBalance with negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(7.0, got.Balance)
|
||||
s.Require().InDelta(7.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
|
||||
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||
s.Require().NoError(err, "DeductBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(5.0, got.Balance)
|
||||
s.Require().InDelta(5.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
|
||||
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().Error(err, "expected error for insufficient balance")
|
||||
@@ -252,20 +307,20 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
|
||||
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "DeductBalance exact amount")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.Balance)
|
||||
s.Require().InDelta(0.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
// --- Concurrency ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
|
||||
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||
s.Require().NoError(err, "UpdateConcurrency")
|
||||
@@ -276,7 +331,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
|
||||
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||
@@ -289,7 +344,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
// --- ExistsByEmail ---
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
s.mustCreateUser(&service.User{Email: "exists@test.com"})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||
s.Require().NoError(err, "ExistsByEmail")
|
||||
@@ -303,34 +358,38 @@ func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
// --- RemoveGroupFromAllowedGroups ---
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
groupID := int64(42)
|
||||
userA := mustCreateUser(s.T(), s.db, &userModel{
|
||||
target := s.mustCreateGroup("target-42")
|
||||
other := s.mustCreateGroup("other-7")
|
||||
|
||||
userA := s.mustCreateUser(&service.User{
|
||||
Email: "a1@example.com",
|
||||
AllowedGroups: pq.Int64Array{groupID, 7},
|
||||
AllowedGroups: []int64{target.ID, other.ID},
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a2@example.com",
|
||||
AllowedGroups: pq.Int64Array{7},
|
||||
AllowedGroups: []int64{other.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
|
||||
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, userA.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
for _, id := range got.AllowedGroups {
|
||||
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
|
||||
}
|
||||
s.Require().NotContains(got.AllowedGroups, target.ID)
|
||||
s.Require().Contains(got.AllowedGroups, other.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
groupA := s.mustCreateGroup("nomatch-a")
|
||||
groupB := s.mustCreateGroup("nomatch-b")
|
||||
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "nomatch@test.com",
|
||||
AllowedGroups: pq.Int64Array{1, 2, 3},
|
||||
AllowedGroups: []int64{groupA.ID, groupB.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected, "expected no affected rows")
|
||||
}
|
||||
@@ -338,12 +397,12 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
// --- GetFirstAdmin ---
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
admin1 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
admin1 := s.mustCreateUser(&service.User{
|
||||
Email: "admin1@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "admin2@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
@@ -355,7 +414,7 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
@@ -366,12 +425,12 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "disabled@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
|
||||
activeAdmin := s.mustCreateUser(&service.User{
|
||||
Email: "active@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
@@ -382,10 +441,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
// --- Combined ---
|
||||
|
||||
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
user1 := s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
@@ -393,7 +452,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
user2 := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
@@ -401,7 +460,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
_ = mustCreateUser(s.T(), s.db, &userModel{
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
@@ -424,12 +483,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
|
||||
got3, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateBalance")
|
||||
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
|
||||
s.Require().InDelta(12.5, got3.Balance, 1e-6)
|
||||
|
||||
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
|
||||
got4, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after DeductBalance")
|
||||
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
|
||||
s.Require().InDelta(7.5, got4.Balance, 1e-6)
|
||||
|
||||
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
|
||||
@@ -438,7 +497,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateConcurrency")
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
|
||||
|
||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
|
||||
@@ -447,3 +506,4 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
|
||||
@@ -4,333 +4,336 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type userSubscriptionRepository struct {
|
||||
db *gorm.DB
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
|
||||
return &userSubscriptionRepository{db: db}
|
||||
func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptionRepository {
|
||||
return &userSubscriptionRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
m := userSubscriptionModelFromService(sub)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
builder := r.client.UserSubscription.Create().
|
||||
SetUserID(sub.UserID).
|
||||
SetGroupID(sub.GroupID).
|
||||
SetExpiresAt(sub.ExpiresAt).
|
||||
SetNillableDailyWindowStart(sub.DailyWindowStart).
|
||||
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
|
||||
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
|
||||
SetDailyUsageUsd(sub.DailyUsageUSD).
|
||||
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
|
||||
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
|
||||
SetNillableAssignedBy(sub.AssignedBy)
|
||||
|
||||
if sub.StartsAt.IsZero() {
|
||||
builder.SetStartsAt(time.Now())
|
||||
} else {
|
||||
builder.SetStartsAt(sub.StartsAt)
|
||||
}
|
||||
if sub.Status != "" {
|
||||
builder.SetStatus(sub.Status)
|
||||
}
|
||||
if !sub.AssignedAt.IsZero() {
|
||||
builder.SetAssignedAt(sub.AssignedAt)
|
||||
}
|
||||
// Keep compatibility with historical behavior: always store notes as a string value.
|
||||
builder.SetNotes(sub.Notes)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyUserSubscriptionModelToService(sub, m)
|
||||
applyUserSubscriptionEntityToService(sub, created)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Preload("AssignedByUser").
|
||||
First(&m, id).Error
|
||||
m, err := r.client.UserSubscription.Query().
|
||||
Where(usersubscription.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
WithAssignedByUser().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
return userSubscriptionEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
First(&m).Error
|
||||
m, err := r.client.UserSubscription.Query().
|
||||
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
return userSubscriptionEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, groupID, service.SubscriptionStatusActive, time.Now()).
|
||||
First(&m).Error
|
||||
m, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDEQ(userID),
|
||||
usersubscription.GroupIDEQ(groupID),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(time.Now()),
|
||||
).
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
return userSubscriptionEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
sub.UpdatedAt = time.Now()
|
||||
m := userSubscriptionModelFromService(sub)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyUserSubscriptionModelToService(sub, m)
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
||||
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
|
||||
SetUserID(sub.UserID).
|
||||
SetGroupID(sub.GroupID).
|
||||
SetStartsAt(sub.StartsAt).
|
||||
SetExpiresAt(sub.ExpiresAt).
|
||||
SetStatus(sub.Status).
|
||||
SetNillableDailyWindowStart(sub.DailyWindowStart).
|
||||
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
|
||||
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
|
||||
SetDailyUsageUsd(sub.DailyUsageUSD).
|
||||
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
|
||||
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
|
||||
SetNillableAssignedBy(sub.AssignedBy).
|
||||
SetAssignedAt(sub.AssignedAt).
|
||||
SetNotes(sub.Notes)
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyUserSubscriptionEntityToService(sub, updated)
|
||||
return nil
|
||||
}
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, service.ErrSubscriptionAlreadyExists)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
|
||||
// Match GORM semantics: deleting a missing row is not an error.
|
||||
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(usersubscription.UserIDEQ(userID)).
|
||||
WithGroup().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
return userSubscriptionEntitiesToService(subs), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, service.SubscriptionStatusActive, time.Now()).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDEQ(userID),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(time.Now()),
|
||||
).
|
||||
WithGroup().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
return userSubscriptionEntitiesToService(subs), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []userSubscriptionModel
|
||||
var total int64
|
||||
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err := query.
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Order("created_at DESC").
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Find(&subs).Error
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
|
||||
subs, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []userSubscriptionModel
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
|
||||
q := r.client.UserSubscription.Query()
|
||||
if userID != nil {
|
||||
query = query.Where("user_id = ?", *userID)
|
||||
q = q.Where(usersubscription.UserIDEQ(*userID))
|
||||
}
|
||||
if groupID != nil {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||
}
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
q = q.Where(usersubscription.StatusEQ(status))
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err := query.
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Preload("AssignedByUser").
|
||||
Order("created_at DESC").
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Find(&subs).Error
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
|
||||
subs, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
WithAssignedByUser().
|
||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
return r.client.UserSubscription.Query().
|
||||
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
|
||||
Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
|
||||
SetExpiresAt(newExpiresAt).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
|
||||
SetStatus(status).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
|
||||
SetNotes(notes).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_window_start": start,
|
||||
"weekly_window_start": start,
|
||||
"monthly_window_start": start,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(id).
|
||||
SetDailyWindowStart(start).
|
||||
SetWeeklyWindowStart(start).
|
||||
SetMonthlyWindowStart(start).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(id).
|
||||
SetDailyUsageUsd(0).
|
||||
SetDailyWindowStart(newWindowStart).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(id).
|
||||
SetWeeklyUsageUsd(0).
|
||||
SetWeeklyWindowStart(newWindowStart).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(id).
|
||||
SetMonthlyUsageUsd(0).
|
||||
SetMonthlyWindowStart(newWindowStart).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
|
||||
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
_, err := r.client.UserSubscription.UpdateOneID(id).
|
||||
AddDailyUsageUsd(costUSD).
|
||||
AddWeeklyUsageUsd(costUSD).
|
||||
AddMonthlyUsageUsd(costUSD).
|
||||
Save(ctx)
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]any{
|
||||
"status": service.SubscriptionStatusExpired,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
return result.RowsAffected, result.Error
|
||||
n, err := r.client.UserSubscription.Update().
|
||||
Where(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtLTE(time.Now()),
|
||||
).
|
||||
SetStatus(service.SubscriptionStatusExpired).
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// Extra repository helpers (currently used only by integration tests).
|
||||
|
||||
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
|
||||
Find(&subs).Error
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtLTE(time.Now()),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
return userSubscriptionEntitiesToService(subs), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("group_id = ? AND status = ? AND expires_at > ?",
|
||||
groupID, service.SubscriptionStatusActive, time.Now()).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
count, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.GroupIDEQ(groupID),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
usersubscription.ExpiresAtGT(time.Now()),
|
||||
).
|
||||
Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
|
||||
return result.RowsAffected, result.Error
|
||||
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
type userSubscriptionModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
GroupID int64 `gorm:"index;not null"`
|
||||
|
||||
StartsAt time.Time `gorm:"not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
|
||||
DailyWindowStart *time.Time
|
||||
WeeklyWindowStart *time.Time
|
||||
MonthlyWindowStart *time.Time
|
||||
|
||||
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
|
||||
AssignedBy *int64 `gorm:"index"`
|
||||
AssignedAt time.Time `gorm:"not null"`
|
||||
Notes string `gorm:"type:text"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
|
||||
}
|
||||
|
||||
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
|
||||
|
||||
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
|
||||
func userSubscriptionEntityToService(m *dbent.UserSubscription) *service.UserSubscription {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.UserSubscription{
|
||||
out := &service.UserSubscription{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
GroupID: m.GroupID,
|
||||
@@ -340,60 +343,42 @@ func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubsc
|
||||
DailyWindowStart: m.DailyWindowStart,
|
||||
WeeklyWindowStart: m.WeeklyWindowStart,
|
||||
MonthlyWindowStart: m.MonthlyWindowStart,
|
||||
DailyUsageUSD: m.DailyUsageUSD,
|
||||
WeeklyUsageUSD: m.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: m.MonthlyUsageUSD,
|
||||
DailyUsageUSD: m.DailyUsageUsd,
|
||||
WeeklyUsageUSD: m.WeeklyUsageUsd,
|
||||
MonthlyUsageUSD: m.MonthlyUsageUsd,
|
||||
AssignedBy: m.AssignedBy,
|
||||
AssignedAt: m.AssignedAt,
|
||||
Notes: m.Notes,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
AssignedByUser: userModelToService(m.AssignedByUser),
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
if m.Edges.AssignedByUser != nil {
|
||||
out.AssignedByUser = userEntityToService(m.Edges.AssignedByUser)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
|
||||
func userSubscriptionEntitiesToService(models []*dbent.UserSubscription) []service.UserSubscription {
|
||||
out := make([]service.UserSubscription, 0, len(models))
|
||||
for i := range models {
|
||||
if s := userSubscriptionModelToService(&models[i]); s != nil {
|
||||
if s := userSubscriptionEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &userSubscriptionModel{
|
||||
ID: s.ID,
|
||||
UserID: s.UserID,
|
||||
GroupID: s.GroupID,
|
||||
StartsAt: s.StartsAt,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
Status: s.Status,
|
||||
DailyWindowStart: s.DailyWindowStart,
|
||||
WeeklyWindowStart: s.WeeklyWindowStart,
|
||||
MonthlyWindowStart: s.MonthlyWindowStart,
|
||||
DailyUsageUSD: s.DailyUsageUSD,
|
||||
WeeklyUsageUSD: s.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: s.MonthlyUsageUSD,
|
||||
AssignedBy: s.AssignedBy,
|
||||
AssignedAt: s.AssignedAt,
|
||||
Notes: s.Notes,
|
||||
CreatedAt: s.CreatedAt,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
|
||||
if sub == nil || m == nil {
|
||||
func applyUserSubscriptionEntityToService(dst *service.UserSubscription, src *dbent.UserSubscription) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
sub.ID = m.ID
|
||||
sub.CreatedAt = m.CreatedAt
|
||||
sub.UpdatedAt = m.UpdatedAt
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -7,34 +7,85 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *userSubscriptionRepository
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *userSubscriptionRepository
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
|
||||
client, _ := testEntSQLTx(s.T())
|
||||
s.client = client
|
||||
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
|
||||
}
|
||||
|
||||
func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserSubscriptionRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateUser(email string, role string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
if role == "" {
|
||||
role = service.RoleUser
|
||||
}
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(role).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
|
||||
s.T().Helper()
|
||||
|
||||
now := time.Now()
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1*time.Hour)).
|
||||
SetExpiresAt(now.Add(24*time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
|
||||
if mutate != nil {
|
||||
mutate(create)
|
||||
}
|
||||
|
||||
sub, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create user subscription")
|
||||
return sub
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
|
||||
user := s.mustCreateUser("sub-create@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-create")
|
||||
|
||||
sub := &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
@@ -54,16 +105,12 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
|
||||
admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
user := s.mustCreateUser("preload@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-preload")
|
||||
admin := s.mustCreateUser("admin@test.com", service.RoleAdmin)
|
||||
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
AssignedBy: &admin.ID,
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetAssignedBy(admin.ID)
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
@@ -82,18 +129,15 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
|
||||
sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}))
|
||||
user := s.mustCreateUser("update@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-update")
|
||||
created := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
sub, err := s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
|
||||
sub.Notes = "updated notes"
|
||||
err := s.repo.Update(s.ctx, sub)
|
||||
s.Require().NoError(err, "Update")
|
||||
s.Require().NoError(s.repo.Update(s.ctx, sub), "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
@@ -101,14 +145,9 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("delete@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-delete")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.Delete(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -117,17 +156,16 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete_Idempotent() {
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, 42424242), "Delete should be idempotent")
|
||||
}
|
||||
|
||||
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("byuser@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-byuser")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetByUserIDAndGroupID")
|
||||
@@ -141,15 +179,11 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
|
||||
user := s.mustCreateUser("active@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-active")
|
||||
|
||||
// Create active subscription (future expiry)
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
active := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
@@ -158,15 +192,11 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
|
||||
user := s.mustCreateUser("expired@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-expired")
|
||||
|
||||
// Create expired subscription (past expiry but active status)
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
|
||||
})
|
||||
|
||||
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
@@ -176,21 +206,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
|
||||
// --- ListByUserID / ListActiveByUserID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
|
||||
user := s.mustCreateUser("listby@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-list1")
|
||||
g2 := s.mustCreateGroup("g-list2")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
|
||||
@@ -202,21 +225,16 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
|
||||
user := s.mustCreateUser("listactive@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-act1")
|
||||
g2 := s.mustCreateGroup("g-act2")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, g1.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
|
||||
@@ -228,22 +246,12 @@ func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||
// --- ListByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
|
||||
user1 := s.mustCreateUser("u1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("u2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-listgrp")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
@@ -258,15 +266,9 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
// --- List with filters ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("list@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -275,22 +277,12 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
|
||||
user1 := s.mustCreateUser("filter1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("filter2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-filter")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||
s.Require().NoError(err)
|
||||
@@ -299,22 +291,12 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
|
||||
user := s.mustCreateUser("grpfilter@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-f1")
|
||||
g2 := s.mustCreateGroup("g-f2")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||
s.Require().NoError(err)
|
||||
@@ -323,20 +305,18 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
|
||||
user1 := s.mustCreateUser("statfilter1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("statfilter2@test.com", service.RoleUser)
|
||||
group1 := s.mustCreateGroup("g-stat-1")
|
||||
group2 := s.mustCreateGroup("g-stat-2")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
s.mustCreateSubscription(user1.ID, group1.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusActive)
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
s.mustCreateSubscription(user2.ID, group2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
||||
@@ -348,52 +328,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
// --- Usage tracking ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("usage@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-usage")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
|
||||
s.Require().NoError(err, "IncrementUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1.25, got.DailyUsageUSD)
|
||||
s.Require().Equal(1.25, got.WeeklyUsageUSD)
|
||||
s.Require().Equal(1.25, got.MonthlyUsageUSD)
|
||||
s.Require().InDelta(1.25, got.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, got.MonthlyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("accum@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-accum")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(3.5, got.DailyUsageUSD)
|
||||
s.Require().InDelta(3.5, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("activate@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-activate")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
|
||||
@@ -404,19 +369,15 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
s.Require().NotNil(got.DailyWindowStart)
|
||||
s.Require().NotNil(got.WeeklyWindowStart)
|
||||
s.Require().NotNil(got.MonthlyWindowStart)
|
||||
s.Require().True(got.DailyWindowStart.Equal(activateAt))
|
||||
s.Require().WithinDuration(activateAt, *got.DailyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyUsageUSD: 10.0,
|
||||
WeeklyUsageUSD: 20.0,
|
||||
user := s.mustCreateUser("resetd@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetd")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetDailyUsageUsd(10.0)
|
||||
c.SetWeeklyUsageUsd(20.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
@@ -425,21 +386,18 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.DailyUsageUSD)
|
||||
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
|
||||
s.Require().True(got.DailyWindowStart.Equal(resetAt))
|
||||
s.Require().InDelta(0.0, got.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(20.0, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.DailyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.DailyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
WeeklyUsageUSD: 15.0,
|
||||
MonthlyUsageUSD: 30.0,
|
||||
user := s.mustCreateUser("resetw@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetw")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetWeeklyUsageUsd(15.0)
|
||||
c.SetMonthlyUsageUsd(30.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
|
||||
@@ -448,20 +406,17 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.WeeklyUsageUSD)
|
||||
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
|
||||
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
|
||||
s.Require().InDelta(0.0, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(30.0, got.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.WeeklyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.WeeklyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
MonthlyUsageUSD: 100.0,
|
||||
user := s.mustCreateUser("resetm@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetm")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetMonthlyUsageUsd(25.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
@@ -470,21 +425,17 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(got.MonthlyUsageUSD)
|
||||
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
|
||||
s.Require().InDelta(0.0, got.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.MonthlyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.MonthlyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("status@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-status")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err, "UpdateStatus")
|
||||
@@ -495,14 +446,9 @@ func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("extend@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-extend")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
|
||||
@@ -510,18 +456,13 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(got.ExpiresAt.Equal(newExpiry))
|
||||
s.Require().WithinDuration(newExpiry, got.ExpiresAt, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
user := s.mustCreateUser("notes@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-notes")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
|
||||
s.Require().NoError(err, "UpdateNotes")
|
||||
@@ -534,20 +475,15 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
// --- ListExpired / BatchUpdateExpiredStatus ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
|
||||
user := s.mustCreateUser("listexp@test.com", service.RoleUser)
|
||||
groupActive := s.mustCreateGroup("g-listexp-active")
|
||||
groupExpired := s.mustCreateGroup("g-listexp-expired")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
expired, err := s.repo.ListExpired(s.ctx)
|
||||
@@ -556,20 +492,15 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
|
||||
user := s.mustCreateUser("batch@test.com", service.RoleUser)
|
||||
groupFuture := s.mustCreateGroup("g-batch-future")
|
||||
groupPast := s.mustCreateGroup("g-batch-past")
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
active := s.mustCreateSubscription(user.ID, groupFuture.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
expiredActive := s.mustCreateSubscription(user.ID, groupPast.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
@@ -586,15 +517,10 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
// --- ExistsByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
|
||||
user := s.mustCreateUser("exists@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-exists")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
|
||||
@@ -608,21 +534,14 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
// --- CountByGroupID / CountActiveByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
user1 := s.mustCreateUser("cnt1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("cnt2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-count")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
@@ -631,21 +550,15 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
|
||||
user1 := s.mustCreateUser("cntact1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("cntact2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-cntact")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
s.mustCreateSubscription(user1.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
|
||||
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) // expired by time
|
||||
})
|
||||
|
||||
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
|
||||
@@ -656,21 +569,12 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
// --- DeleteByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
|
||||
user1 := s.mustCreateUser("delgrp1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("delgrp2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-delgrp")
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "DeleteByGroupID")
|
||||
@@ -680,26 +584,21 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
// --- Combined scenario ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
|
||||
user := s.mustCreateUser("subr@example.com", service.RoleUser)
|
||||
groupActive := s.mustCreateGroup("g-subr-active")
|
||||
groupExpired := s.mustCreateGroup("g-subr-expired")
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
active := s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
expiredActive := s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, groupActive.ID)
|
||||
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||
s.Require().Equal(active.ID, got.ID, "expected active subscription")
|
||||
|
||||
@@ -709,9 +608,9 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
|
||||
after, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
|
||||
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
|
||||
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
|
||||
s.Require().InDelta(1.25, after.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, after.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, after.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
|
||||
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
|
||||
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
|
||||
@@ -720,14 +619,16 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
|
||||
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID after reset")
|
||||
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
|
||||
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
|
||||
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
|
||||
s.Require().InDelta(0.0, afterReset.DailyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(afterReset.DailyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *afterReset.DailyWindowStart, time.Microsecond)
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().NoError(err, "GetByID expired")
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
|
||||
214
backend/internal/server/middleware/api_key_auth_google_test.go
Normal file
214
backend/internal/server/middleware/api_key_auth_google_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") }
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") }
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") }
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { return 0, errors.New("not implemented") }
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { return false, errors.New("not implemented") }
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is required", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "Invalid API key", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer any")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
|
||||
require.Equal(t, "Failed to validate API key", resp.Error.Message)
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer disabled")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is disabled", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer ok")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusForbidden, resp.Error.Code)
|
||||
require.Equal(t, "Insufficient account balance", resp.Error.Message)
|
||||
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package setup
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -10,13 +11,12 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gopkg.in/yaml.v3"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Config paths
|
||||
@@ -92,20 +92,16 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(defaultDSN), &gorm.Config{})
|
||||
db, err := sql.Open("postgres", defaultDSN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get db instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if sqlDB == nil {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -113,22 +109,23 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if target database exists
|
||||
var exists bool
|
||||
row := sqlDB.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", cfg.DBName)
|
||||
row := db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", cfg.DBName)
|
||||
if err := row.Scan(&exists); err != nil {
|
||||
return fmt.Errorf("failed to check database existence: %w", err)
|
||||
}
|
||||
|
||||
// Create database if not exists
|
||||
if !exists {
|
||||
// 注意:数据库名不能参数化,依赖前置输入校验保障安全。
|
||||
// Note: Database names cannot be parameterized, but we've already validated cfg.DBName
|
||||
// in the handler using validateDBName() which only allows [a-zA-Z][a-zA-Z0-9_]*
|
||||
_, err := sqlDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName))
|
||||
_, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err)
|
||||
}
|
||||
@@ -136,27 +133,23 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
}
|
||||
|
||||
// Now connect to the target database to verify
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
sqlDB = nil
|
||||
db = nil
|
||||
|
||||
targetDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode,
|
||||
)
|
||||
|
||||
targetDB, err := gorm.Open(postgres.Open(targetDSN), &gorm.Config{})
|
||||
targetDB, err := sql.Open("postgres", targetDSN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database '%s': %w", cfg.DBName, err)
|
||||
}
|
||||
|
||||
targetSqlDB, err := targetDB.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target db instance: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := targetSqlDB.Close(); err != nil {
|
||||
if err := targetDB.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -164,7 +157,7 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
if err := targetSqlDB.PingContext(ctx2); err != nil {
|
||||
if err := targetDB.PingContext(ctx2); err != nil {
|
||||
return fmt.Errorf("ping target database failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -256,22 +249,18 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return repository.AutoMigrate(db)
|
||||
return infrastructure.ApplyMigrations(context.Background(), db)
|
||||
}
|
||||
|
||||
func createAdminUser(cfg *SetupConfig) error {
|
||||
@@ -281,24 +270,24 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Printf("failed to close postgres connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用超时上下文避免安装流程因数据库异常而长时间阻塞。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check if admin already exists
|
||||
var count int64
|
||||
if err := db.Table("users").Where("role = ?", service.RoleAdmin).Count(&count).Error; err != nil {
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
@@ -319,7 +308,20 @@ func createAdminUser(cfg *SetupConfig) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return repository.NewUserRepository(db).Create(context.Background(), admin)
|
||||
_, err = db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
admin.Email,
|
||||
admin.PasswordHash,
|
||||
admin.Role,
|
||||
admin.Balance,
|
||||
admin.Concurrency,
|
||||
admin.Status,
|
||||
admin.CreatedAt,
|
||||
admin.UpdatedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeConfigFile(cfg *SetupConfig) error {
|
||||
@@ -339,7 +341,10 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
ExpireHour int `yaml:"expire_hour"`
|
||||
} `yaml:"jwt"`
|
||||
Default struct {
|
||||
GroupID uint `yaml:"group_id"`
|
||||
UserConcurrency int `yaml:"user_concurrency"`
|
||||
UserBalance float64 `yaml:"user_balance"`
|
||||
ApiKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
} `yaml:"default"`
|
||||
RateLimit struct {
|
||||
RequestsPerMinute int `yaml:"requests_per_minute"`
|
||||
@@ -358,9 +363,15 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
ExpireHour: cfg.JWT.ExpireHour,
|
||||
},
|
||||
Default: struct {
|
||||
GroupID uint `yaml:"group_id"`
|
||||
UserConcurrency int `yaml:"user_concurrency"`
|
||||
UserBalance float64 `yaml:"user_balance"`
|
||||
ApiKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
}{
|
||||
GroupID: 1,
|
||||
UserConcurrency: 5,
|
||||
UserBalance: 0,
|
||||
ApiKeyPrefix: "sk-",
|
||||
RateMultiplier: 1.0,
|
||||
},
|
||||
RateLimit: struct {
|
||||
RequestsPerMinute int `yaml:"requests_per_minute"`
|
||||
|
||||
Reference in New Issue
Block a user