refactor(数据库): 迁移持久层到 Ent 并清理 GORM

将仓储层/基础设施改为 Ent + 原生 SQL 执行路径,并移除 AutoMigrate 与 GORM 依赖。
重构内容包括:
- 仓储层改用 Ent/SQL(含 usage_log/account 等复杂查询),统一错误映射
- 基础设施与 setup 初始化切换为 Ent + SQL migrations
- 集成测试与 fixtures 迁移到 Ent 事务模型
- 清理遗留 GORM 模型/依赖,补充迁移与文档说明
- 增加根目录 Makefile 便于前后端编译

测试:
- go test -tags unit ./...
- go test -tags integration ./...
This commit is contained in:
yangjianbo
2025-12-29 10:03:27 +08:00
parent fd51ff6970
commit 3d617de577
149 changed files with 62892 additions and 3212 deletions

View File

@@ -195,8 +195,10 @@ func setDefaults() {
viper.SetDefault("jwt.expire_hour", 24)
// Default
viper.SetDefault("default.admin_email", "admin@sub2api.com")
viper.SetDefault("default.admin_password", "admin123")
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
viper.SetDefault("default.admin_email", "")
viper.SetDefault("default.admin_password", "")
viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-")

View File

@@ -1,38 +0,0 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := repository.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,65 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure
import (
"context"
"database/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/migrations"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
)
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
//
// 该函数执行以下操作:
// 1. 初始化全局时区设置,确保时间处理一致性
// 2. 建立 PostgreSQL 数据库连接
// 3. 自动执行数据库迁移,确保 schema 与代码同步
// 4. 创建并返回 Ent 客户端实例
//
// 重要提示:调用者必须负责关闭返回的 ent.Client关闭时会自动关闭底层的 driver/db
//
// 参数:
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
//
// 返回:
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
// - error: 初始化过程中的错误
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
// 这对于跨时区部署和日志时间戳的一致性至关重要。
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, nil, err
}
// 构建包含时区信息的数据库连接字符串 (DSN)。
// 时区信息会传递给 PostgreSQL确保数据库层面的时间处理正确。
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
drv, err := entsql.Open(dialect.Postgres, dsn)
if err != nil {
return nil, nil, err
}
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
if err := applyMigrationsFS(context.Background(), drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
return nil, nil, err
}
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
return client, drv.DB(), nil
}

View File

@@ -0,0 +1,184 @@
package infrastructure
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"github.com/Wei-Shaw/sub2api/migrations"
)
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
// 该表用于跟踪已应用的迁移文件及其校验和。
// - filename: 迁移文件名,作为主键唯一标识每个迁移
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
// - applied_at: 迁移应用时间戳
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
checksum TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
// 该函数可以在每次应用启动时安全调用:
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
// - 如果迁移文件内容被修改checksum 不匹配),会返回错误
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
//
// 参数:
// - ctx: 上下文,用于超时控制和取消
// - db: 数据库连接
//
// 返回:
// - error: 迁移过程中的任何错误
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if db == nil {
return errors.New("nil sql db")
}
return applyMigrationsFS(ctx, db, migrations.FS)
}
// applyMigrationsFS 是迁移执行的核心实现。
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
//
// 迁移执行流程:
// 1. 获取 PostgreSQL Advisory Lock防止多实例并发迁移
// 2. 确保 schema_migrations 表存在
// 3. 按文件名排序读取所有 .sql 文件
// 4. 对于每个迁移文件:
// - 计算文件内容的 SHA256 校验和
// - 检查该迁移是否已应用(通过 filename 查询)
// - 如果已应用,验证校验和是否匹配
// - 如果未应用,在事务中执行迁移并记录
// 5. 释放 Advisory Lock
//
// 参数:
// - ctx: 上下文
// - db: 数据库连接
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if db == nil {
return errors.New("nil sql db")
}
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
if err := pgAdvisoryLock(ctx, db); err != nil {
return err
}
defer func() {
// 无论迁移是否成功,都要释放锁。
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
_ = pgAdvisoryUnlock(context.Background(), db)
}()
// 创建迁移记录表(如果不存在)。
// 该表记录所有已应用的迁移及其校验和。
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return fmt.Errorf("list migrations: %w", err)
}
sort.Strings(files) // 确保按文件名顺序执行迁移
for _, name := range files {
// 读取迁移文件内容
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return fmt.Errorf("read migration %s: %w", name, err)
}
content := strings.TrimSpace(string(contentBytes))
if content == "" {
continue // 跳过空文件
}
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
sum := sha256.Sum256([]byte(content))
checksum := hex.EncodeToString(sum[:])
// 检查该迁移是否已经应用
var existing string
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
}
continue // 迁移已应用且校验和匹配,跳过
}
if !errors.Is(rowErr, sql.ErrNoRows) {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
}
// 执行迁移 SQL
if _, err := tx.ExecContext(ctx, content); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %s: %w", name, err)
}
// 记录迁移已完成,保存文件名和校验和
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %s: %w", name, err)
}
// 提交事务
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return fmt.Errorf("commit migration %s: %w", name, err)
}
}
return nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_lock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("acquire migrations lock: %w", err)
}
return nil
}
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("release migrations lock: %w", err)
}
return nil
}

View File

@@ -1,25 +1,79 @@
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 提供基础设施层的依赖
// ProviderSet 基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideRedis 提供 Redis 客户端
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,27 +4,31 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
tx *sql.Tx
client *dbent.Client
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db).(*accountRepository)
client, tx := testEntSQLTx(s.T())
s.client = client
s.tx = tx
s.repo = newAccountRepositoryWithSQL(client, tx)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -61,7 +65,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
}
func (s *AccountRepoSuite) TestUpdate() {
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
@@ -73,7 +77,7 @@ func (s *AccountRepoSuite) TestUpdate() {
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
@@ -83,23 +87,23 @@ func (s *AccountRepoSuite) TestDelete() {
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -110,7 +114,7 @@ func (s *AccountRepoSuite) TestList() {
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
setup func(client *dbent.Client)
platform string
accType string
status string
@@ -120,9 +124,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: service.PlatformOpenAI,
wantCount: 1,
@@ -132,9 +136,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeApiKey,
wantCount: 1,
@@ -144,9 +148,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
},
status: service.StatusDisabled,
wantCount: 1,
@@ -156,9 +160,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
@@ -171,11 +175,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db).(*accountRepository)
client, tx := testEntSQLTx(s.T())
repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background()
tt.setup(db)
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
@@ -190,11 +194,11 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
@@ -204,8 +208,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -214,8 +218,8 @@ func (s *AccountRepoSuite) TestListActive() {
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
@@ -226,14 +230,14 @@ func (s *AccountRepoSuite) TestListByPlatform() {
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &accountModel{
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
@@ -257,9 +261,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
@@ -279,9 +283,9 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
@@ -294,14 +298,14 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
@@ -312,17 +316,17 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
@@ -339,8 +343,8 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
@@ -349,11 +353,11 @@ func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
@@ -362,7 +366,7 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
@@ -374,7 +378,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
@@ -386,7 +390,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
@@ -399,7 +403,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
@@ -416,7 +420,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
@@ -429,7 +433,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
@@ -442,7 +446,7 @@ func (s *AccountRepoSuite) TestSetError() {
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
@@ -458,9 +462,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &accountModel{
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra",
Extra: datatypes.JSONMap{"a": "1"},
Extra: map[string]any{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
@@ -471,12 +475,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
@@ -488,9 +492,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &accountModel{
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-crs",
Extra: datatypes.JSONMap{"crs_account_id": crsID},
Extra: map[string]any{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
@@ -514,8 +518,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
@@ -531,13 +535,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-cred",
Credentials: datatypes.JSONMap{"existing": "value"},
Credentials: map[string]any{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: datatypes.JSONMap{"new_key": "new_value"},
Credentials: map[string]any{"new_key": "new_value"},
})
s.Require().NoError(err)
@@ -547,13 +551,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-extra",
Extra: datatypes.JSONMap{"existing": "val"},
Extra: map[string]any{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: datatypes.JSONMap{"new_key": "new_val"},
Extra: map[string]any{"new_key": "new_val"},
})
s.Require().NoError(err)
@@ -569,7 +573,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)

View File

@@ -0,0 +1,144 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueTestValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "other-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, sqlTx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u1))
u2 := &service.User{
Email: uniqueTestValue(t, "u2") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID},
}
require.NoError(t, repo.Create(ctx, u2))
u3 := &service.User{
Email: uniqueTestValue(t, "u3") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u3))
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
require.NoError(t, err)
require.Equal(t, int64(2), affected)
u1After, err := repo.GetByID(ctx, u1.ID)
require.NoError(t, err)
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
u2After, err := repo.GetByID(ctx, u2.ID)
require.NoError(t, err)
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
entClient, sqlTx := testEntSQLTx(t)
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-other")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, sqlTx)
groupRepo := newGroupRepositoryWithSQL(entClient, sqlTx)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",
GroupID: &targetGroup.ID,
Status: service.StatusActive,
}
require.NoError(t, apiKeyRepo.Create(ctx, key))
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
require.NoError(t, err)
// Deleted group should be hidden by default queries (soft-delete semantics).
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
require.ErrorIs(t, err, service.ErrGroupNotFound)
activeGroups, err := groupRepo.ListActive(ctx)
require.NoError(t, err)
for _, g := range activeGroups {
require.NotEqual(t, targetGroup.ID, g.ID)
}
// User.allowed_groups should no longer include the deleted group.
uAfter, err := userRepo.GetByID(ctx, u.ID)
require.NoError(t, err)
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}

View File

@@ -2,83 +2,118 @@ package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type apiKeyRepository struct {
db *gorm.DB
client *dbent.Client
}
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
created, err := r.client.ApiKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
if err == nil {
applyApiKeyModelToService(key, m)
key.ID = created.ID
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
m, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
return apiKeyModelToService(&m), nil
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
m, err := r.client.ApiKey.Query().
Where(apikey.KeyEQ(key)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
return apiKeyModelToService(&m), nil
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
builder := r.client.ApiKey.UpdateOneID(key.ID).
SetName(key.Name).
SetStatus(key.Status)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
} else {
builder.ClearGroupID()
}
updated, err := builder.Save(ctx)
if err == nil {
applyApiKeyModelToService(key, m)
key.UpdatedAt = updated.UpdatedAt
return nil
}
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
}
return err
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
_, err := r.client.ApiKey.Delete().Where(apikey.IDEQ(id)).Exec(ctx)
return err
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
q := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID))
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(total, params), nil
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
@@ -86,11 +121,9 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids := make([]int64, 0, len(apiKeyIDs))
err := r.db.WithContext(ctx).
Model(&apiKeyModel{}).
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
Pluck("id", &ids).Error
ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...)).
IDs(ctx)
if err != nil {
return nil, err
}
@@ -98,136 +131,146 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
count, err := r.client.ApiKey.Query().Where(apikey.UserIDEQ(userID)).Count(ctx)
return int64(count), err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
count, err := r.client.ApiKey.Query().Where(apikey.KeyEQ(key)).Count(ctx)
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
q := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID))
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("User").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(total, params), nil
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
q := r.client.ApiKey.Query()
if userID > 0 {
db = db.Where("user_id = ?", userID)
q = q.Where(apikey.UserIDEQ(userID))
}
if keyword != "" {
searchPattern := "%" + keyword + "%"
db = db.Where("name ILIKE ?", searchPattern)
q = q.Where(apikey.NameContainsFold(keyword))
}
if err := db.Limit(limit).Order("id DESC").Find(&keys).Error; err != nil {
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
if err != nil {
return nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID)).
ClearGroupID().
Save(ctx)
return int64(n), err
}
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
count, err := r.client.ApiKey.Query().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
out := &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
func groupEntityToService(g *dbent.Group) *service.Group {
if g == nil {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}

View File

@@ -6,23 +6,24 @@ import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *apiKeyRepository
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
entClient, _ := testEntSQLTx(s.T())
s.client = entClient
s.repo = NewApiKeyRepository(entClient).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
@@ -32,7 +33,7 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{
UserID: user.ID,
@@ -56,16 +57,17 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
})
}
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
@@ -84,13 +86,14 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: service.StatusActive,
}))
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.Name = "Renamed"
key.Status = service.StatusDisabled
@@ -106,14 +109,16 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
}))
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
@@ -127,12 +132,14 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
@@ -144,9 +151,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -155,13 +162,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
@@ -172,9 +175,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -184,12 +187,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -200,10 +203,9 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -213,8 +215,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -228,9 +230,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
@@ -239,9 +241,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
@@ -249,8 +251,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
@@ -260,12 +262,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -283,16 +285,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}))
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
@@ -330,12 +326,8 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -353,3 +345,41 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
s.T().Helper()
k := &service.ApiKey{
UserID: userID,
Key: key,
Name: name,
GroupID: groupID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}

View File

@@ -1,49 +0,0 @@
package repository
import (
"log"
"time"
"gorm.io/gorm"
)
// MaxExpiresAt is the maximum allowed expiration date for subscriptions (year 2099)
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
err := db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
if err != nil {
return err
}
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
return fixInvalidExpiresAt(db)
}
// fixInvalidExpiresAt 修复 user_subscriptions 表中无效的过期时间
func fixInvalidExpiresAt(db *gorm.DB) error {
result := db.Model(&userSubscriptionModel{}).
Where("expires_at > ?", maxExpiresAt).
Update("expires_at", maxExpiresAt)
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
log.Printf("[AutoMigrate] Fixed %d subscriptions with invalid expires_at (year > 2099)", result.RowsAffected)
}
return nil
}

View File

@@ -1,38 +1,75 @@
package repository
import (
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
"github.com/lib/pq"
)
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows
//
// 参数:
// - err: 原始数据库错误
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
//
// 返回:
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
//
// 示例:
//
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
// 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
// Ent 使用自定义的 NotFoundError而标准库使用 sql.ErrNoRows。
// 这里同时处理两种情况,保持业务错误映射一致。
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
return notFound.WithCause(err)
}
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
// 未匹配任何规则,返回原始错误
return err
}
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
//
// 支持多种检测方式:
// 1. PostgreSQL 特定错误码 23505唯一约束冲突
// 2. 错误消息中包含的通用关键词
//
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
// 优先检测 PostgreSQL 特定错误码(最精确)。
// 错误码 23505 对应 unique_violation。
// 参考https://www.postgresql.org/docs/current/errcodes-appendix.html
var pgErr *pq.Error
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
// 回退到错误消息检测(兼容其他场景)。
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||

View File

@@ -3,17 +3,22 @@
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
t.Helper()
ctx := context.Background()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
@@ -26,18 +31,48 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
if u.Concurrency == 0 {
u.Concurrency = 5
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
create := client.User.Create().
SetEmail(u.Email).
SetPasswordHash(u.PasswordHash).
SetRole(u.Role).
SetStatus(u.Status).
SetBalance(u.Balance).
SetConcurrency(u.Concurrency).
SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes)
if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt)
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = u.CreatedAt
if !u.UpdatedAt.IsZero() {
create.SetUpdatedAt(u.UpdatedAt)
}
require.NoError(t, db.Create(u).Error, "create user")
created, err := create.Save(ctx)
require.NoError(t, err, "create user")
u.ID = created.ID
u.CreatedAt = created.CreatedAt
u.UpdatedAt = created.UpdatedAt
if len(u.AllowedGroups) > 0 {
for _, groupID := range u.AllowedGroups {
_, err := client.UserAllowedGroup.Create().
SetUserID(u.ID).
SetGroupID(groupID).
Save(ctx)
require.NoError(t, err, "create user_allowed_groups row")
}
}
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
t.Helper()
ctx := context.Background()
if g.Platform == "" {
g.Platform = service.PlatformAnthropic
}
@@ -47,18 +82,46 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
if g.SubscriptionType == "" {
g.SubscriptionType = service.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
create := client.Group.Create().
SetName(g.Name).
SetPlatform(g.Platform).
SetStatus(g.Status).
SetSubscriptionType(g.SubscriptionType).
SetRateMultiplier(g.RateMultiplier).
SetIsExclusive(g.IsExclusive)
if g.Description != "" {
create.SetDescription(g.Description)
}
if g.UpdatedAt.IsZero() {
g.UpdatedAt = g.CreatedAt
if g.DailyLimitUSD != nil {
create.SetDailyLimitUsd(*g.DailyLimitUSD)
}
require.NoError(t, db.Create(g).Error, "create group")
if g.WeeklyLimitUSD != nil {
create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
}
if g.MonthlyLimitUSD != nil {
create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
}
if !g.CreatedAt.IsZero() {
create.SetCreatedAt(g.CreatedAt)
}
if !g.UpdatedAt.IsZero() {
create.SetUpdatedAt(g.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create group")
g.ID = created.ID
g.CreatedAt = created.CreatedAt
g.UpdatedAt = created.UpdatedAt
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
t.Helper()
ctx := context.Background()
if p.Protocol == "" {
p.Protocol = "http"
}
@@ -71,18 +134,39 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
if p.Status == "" {
p.Status = service.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
create := client.Proxy.Create().
SetName(p.Name).
SetProtocol(p.Protocol).
SetHost(p.Host).
SetPort(p.Port).
SetStatus(p.Status)
if p.Username != "" {
create.SetUsername(p.Username)
}
if p.UpdatedAt.IsZero() {
p.UpdatedAt = p.CreatedAt
if p.Password != "" {
create.SetPassword(p.Password)
}
require.NoError(t, db.Create(p).Error, "create proxy")
if !p.CreatedAt.IsZero() {
create.SetCreatedAt(p.CreatedAt)
}
if !p.UpdatedAt.IsZero() {
create.SetUpdatedAt(p.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create proxy")
p.ID = created.ID
p.CreatedAt = created.CreatedAt
p.UpdatedAt = created.UpdatedAt
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
t.Helper()
ctx := context.Background()
if a.Platform == "" {
a.Platform = service.PlatformAnthropic
}
@@ -92,57 +176,158 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel
if a.Status == "" {
a.Status = service.StatusActive
}
if a.Concurrency == 0 {
a.Concurrency = 3
}
if a.Priority == 0 {
a.Priority = 50
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = datatypes.JSONMap{}
a.Credentials = map[string]any{}
}
if a.Extra == nil {
a.Extra = datatypes.JSONMap{}
a.Extra = map[string]any{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
create := client.Account.Create().
SetName(a.Name).
SetPlatform(a.Platform).
SetType(a.Type).
SetCredentials(a.Credentials).
SetExtra(a.Extra).
SetConcurrency(a.Concurrency).
SetPriority(a.Priority).
SetStatus(a.Status).
SetSchedulable(a.Schedulable).
SetErrorMessage(a.ErrorMessage)
if a.ProxyID != nil {
create.SetProxyID(*a.ProxyID)
}
if a.UpdatedAt.IsZero() {
a.UpdatedAt = a.CreatedAt
if a.LastUsedAt != nil {
create.SetLastUsedAt(*a.LastUsedAt)
}
require.NoError(t, db.Create(a).Error, "create account")
if a.RateLimitedAt != nil {
create.SetRateLimitedAt(*a.RateLimitedAt)
}
if a.RateLimitResetAt != nil {
create.SetRateLimitResetAt(*a.RateLimitResetAt)
}
if a.OverloadUntil != nil {
create.SetOverloadUntil(*a.OverloadUntil)
}
if a.SessionWindowStart != nil {
create.SetSessionWindowStart(*a.SessionWindowStart)
}
if a.SessionWindowEnd != nil {
create.SetSessionWindowEnd(*a.SessionWindowEnd)
}
if a.SessionWindowStatus != "" {
create.SetSessionWindowStatus(a.SessionWindowStatus)
}
if !a.CreatedAt.IsZero() {
create.SetCreatedAt(a.CreatedAt)
}
if !a.UpdatedAt.IsZero() {
create.SetUpdatedAt(a.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create account")
a.ID = created.ID
a.CreatedAt = created.CreatedAt
a.UpdatedAt = created.UpdatedAt
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
t.Helper()
ctx := context.Background()
if k.Status == "" {
k.Status = service.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
if k.Key == "" {
k.Key = "sk-" + time.Now().Format("150405.000000")
}
if k.UpdatedAt.IsZero() {
k.UpdatedAt = k.CreatedAt
if k.Name == "" {
k.Name = "default"
}
require.NoError(t, db.Create(k).Error, "create api key")
create := client.ApiKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}
if !k.CreatedAt.IsZero() {
create.SetCreatedAt(k.CreatedAt)
}
if !k.UpdatedAt.IsZero() {
create.SetUpdatedAt(k.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create api key")
k.ID = created.ID
k.CreatedAt = created.CreatedAt
k.UpdatedAt = created.UpdatedAt
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
t.Helper()
ctx := context.Background()
if c.Status == "" {
c.Status = service.StatusUnused
}
if c.Type == "" {
c.Type = service.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
if c.Code == "" {
c.Code = "rc-" + time.Now().Format("150405.000000")
}
require.NoError(t, db.Create(c).Error, "create redeem code")
create := client.RedeemCode.Create().
SetCode(c.Code).
SetType(c.Type).
SetValue(c.Value).
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays)
if c.UsedBy != nil {
create.SetUsedBy(*c.UsedBy)
}
if c.UsedAt != nil {
create.SetUsedAt(*c.UsedAt)
}
if c.GroupID != nil {
create.SetGroupID(*c.GroupID)
}
if !c.CreatedAt.IsZero() {
create.SetCreatedAt(c.CreatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create redeem code")
c.ID = created.ID
c.CreatedAt = created.CreatedAt
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
t.Helper()
ctx := context.Background()
if s.Status == "" {
s.Status = service.SubscriptionStatusActive
}
@@ -162,16 +347,47 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel)
if s.UpdatedAt.IsZero() {
s.UpdatedAt = now
}
require.NoError(t, db.Create(s).Error, "create user subscription")
create := client.UserSubscription.Create().
SetUserID(s.UserID).
SetGroupID(s.GroupID).
SetStartsAt(s.StartsAt).
SetExpiresAt(s.ExpiresAt).
SetStatus(s.Status).
SetAssignedAt(s.AssignedAt).
SetNotes(s.Notes).
SetDailyUsageUsd(s.DailyUsageUSD).
SetWeeklyUsageUsd(s.WeeklyUsageUSD).
SetMonthlyUsageUsd(s.MonthlyUsageUSD)
if s.AssignedBy != nil {
create.SetAssignedBy(*s.AssignedBy)
}
if !s.CreatedAt.IsZero() {
create.SetCreatedAt(s.CreatedAt)
}
if !s.UpdatedAt.IsZero() {
create.SetUpdatedAt(s.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create user subscription")
s.ID = created.ID
s.CreatedAt = created.CreatedAt
s.UpdatedAt = created.UpdatedAt
return s
}
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group")
ctx := context.Background()
_, err := client.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(priority).
Save(ctx)
require.NoError(t, err, "create account_group")
}

View File

@@ -2,280 +2,370 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"database/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
)
type sqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type sqlBeginner interface {
sqlExecutor
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type groupRepository struct {
db *gorm.DB
client *dbent.Client
sql sqlExecutor
begin sqlBeginner
}
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db}
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
return newGroupRepositoryWithSQL(client, sqlDB)
}
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &groupRepository{client: client, sql: sqlq, begin: beginner}
}
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
builder := r.client.Group.Create().
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
created, err := builder.Save(ctx)
if err == nil {
applyGroupModelToService(group, m)
groupIn.ID = created.ID
groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var m groupModel
err := r.db.WithContext(ctx).First(&m, id).Error
m, err := r.client.Group.Query().
Where(group.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
out := groupEntityToService(m)
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
return err
groupIn.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []groupModel
var total int64
q := r.client.Group.Query()
db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
q = q.Where(group.PlatformEQ(platform))
}
if status != "" {
db = db.Where("status = ?", status)
q = q.Where(group.StatusEQ(status))
}
if isExclusive != nil {
db = db.Where("is_exclusive = ?", *isExclusive)
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id ASC").Find(&groups).Error; err != nil {
groups, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, paginationResultFromTotal(total, params), nil
return outGroups, paginationResultFromTotal(int64(total), params), nil
}
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, nil
}
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, nil
}
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error
res, err := r.sql.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", groupID)
if err != nil {
return 0, err
}
affected, _ := res.RowsAffected()
return affected, nil
}
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
group, err := r.GetByID(ctx, id)
g, err := r.client.Group.Query().Where(group.IDEQ(id)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
groupSvc := groupEntityToService(g)
exec := r.sql
txClient := r.client
var sqlTx *sql.Tx
var txClientClose func() error
if r.begin != nil {
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
txClientClose = txClient.Close
defer func() { _ = sqlTx.Rollback() }()
}
if txClientClose != nil {
defer func() { _ = txClientClose() }()
}
// Lock the group row to avoid concurrent writes while we cascade.
var lockedID int64
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
if errorsIsNoRows(err) {
return nil, service.ErrGroupNotFound
}
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
if err := r.db.WithContext(ctx).
Table("user_subscriptions").
Where("group_id = ?", id).
Pluck("user_id", &affectedUserIDs).Error; err != nil {
if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
_ = rows.Close()
return nil, scanErr
}
affectedUserIDs = append(affectedUserIDs, userID)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
return nil, err
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Exec(
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
id, id,
).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
// 2. Clear group_id for api keys bound to this group.
if _, err := txClient.ApiKey.Update().
Where(apikey.GroupIDEQ(id)).
ClearGroupID().
Save(ctx); err != nil {
return nil, err
}
// 3. Remove the group id from users.allowed_groups array (legacy representation).
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
// 5. Soft-delete group itself.
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
return nil, err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
return nil, err
}
}
return affectedUserIDs, nil
}
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (map[int64]int64, error) {
counts := make(map[int64]int64, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
rows, err := r.sql.QueryContext(
ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer rows.Close()
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
for rows.Next() {
var groupID int64
var count int64
if err := rows.Scan(&groupID, &count); err != nil {
return nil, err
}
counts[groupID] = count
}
if err := rows.Err(); err != nil {
return nil, err
}
return counts, nil
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
return dbent.NewClient(dbent.Driver(drv))
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
func errorsIsNoRows(err error) bool {
return err == sql.ErrNoRows
}

View File

@@ -4,25 +4,26 @@ package repository
import (
"context"
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
tx *sql.Tx
repo *groupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db).(*groupRepository)
entClient, tx := testEntSQLTx(s.T())
s.tx = tx
s.repo = newGroupRepositoryWithSQL(entClient, tx)
}
func TestGroupRepoSuite(t *testing.T) {
@@ -33,9 +34,12 @@ func TestGroupRepoSuite(t *testing.T) {
func (s *GroupRepoSuite) TestCreate() {
group := &service.Group{
Name: "test-create",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
Name: "test-create",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
err := s.repo.Create(s.ctx, group)
@@ -50,10 +54,19 @@ func (s *GroupRepoSuite) TestCreate() {
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestUpdate() {
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group := &service.Group{
Name: "original",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
@@ -65,20 +78,43 @@ func (s *GroupRepoSuite) TestUpdate() {
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
group := &service.Group{
Name: "to-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -87,8 +123,22 @@ func (s *GroupRepoSuite) TestList() {
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err)
@@ -97,8 +147,22 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err)
@@ -107,8 +171,22 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
@@ -118,21 +196,35 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
IsExclusive: true,
})
g1 := &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
var accountID int64
s.Require().NoError(s.tx.QueryRowContext(
s.ctx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc1", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&accountID))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
s.Require().NoError(err)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
@@ -146,8 +238,22 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "active1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "inactive1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -156,9 +262,30 @@ func (s *GroupRepoSuite) TestListActive() {
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
@@ -169,7 +296,14 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "existing-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
@@ -183,11 +317,33 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
group := &service.Group{
Name: "g-count",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
var a1 int64
s.Require().NoError(s.tx.QueryRowContext(
s.ctx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a1", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&a1))
var a2 int64
s.Require().NoError(s.tx.QueryRowContext(
s.ctx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"a2", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&a2))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
@@ -195,7 +351,15 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
group := &service.Group{
Name: "g-empty",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
@@ -205,9 +369,23 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
g := &service.Group{
Name: "g-del",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var accountID int64
s.Require().NoError(s.tx.QueryRowContext(
s.ctx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&accountID))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
@@ -219,13 +397,34 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
g := &service.Group{
Name: "g-multi",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
insertAccount := func(name string) int64 {
var id int64
s.Require().NoError(s.tx.QueryRowContext(
s.ctx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
name, service.PlatformAnthropic, service.AccountTypeOAuth,
).Scan(&id))
return id
}
a1 := insertAccount("a1")
a2 := insertAccount("a2")
a3 := insertAccount("a3")
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)

View File

@@ -15,16 +15,19 @@ import (
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
gormpostgres "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
@@ -33,7 +36,7 @@ const (
)
var (
integrationDB *gorm.DB
integrationDB *sql.DB
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
@@ -88,13 +91,13 @@ func TestMain(m *testing.M) {
os.Exit(1)
}
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
integrationDB, err = openSQLWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open gorm db: %v", err)
log.Printf("failed to open sql db: %v", err)
os.Exit(1)
}
if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err)
os.Exit(1)
}
@@ -121,6 +124,7 @@ func TestMain(m *testing.M) {
code := m.Run()
_ = integrationRedis.Close()
_ = integrationDB.Close()
os.Exit(code)
}
@@ -147,29 +151,21 @@ func dockerImageExists(ctx context.Context, image string) bool {
return cmd.Run() == nil
}
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
func openSQLWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*sql.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
db, err := sql.Open("postgres", dsn)
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
sqlDB, err := db.DB()
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
if err := pingWithTimeout(ctx, db, 2*time.Second); err != nil {
lastErr = err
_ = db.Close()
time.Sleep(250 * time.Millisecond)
continue
}
@@ -186,17 +182,31 @@ func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) err
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *gorm.DB {
func testTx(t *testing.T) *sql.Tx {
t.Helper()
tx := integrationDB.Begin()
require.NoError(t, tx.Error, "begin tx")
tx, err := integrationDB.BeginTx(context.Background(), nil)
require.NoError(t, err, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback().Error
_ = tx.Rollback()
})
return tx
}
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
t.Helper()
tx := testTx(t)
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
client := dbent.NewClient(dbent.Driver(drv))
t.Cleanup(func() {
_ = client.Close()
})
return client, tx
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
@@ -347,18 +357,19 @@ func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
// Embedding suites should call SetupTest to initialize ctx and db.
// IntegrationDBSuite provides a base suite for DB integration tests.
// Embedding suites should call SetupTest to initialize ctx and client.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
ctx context.Context
client *dbent.Client
tx *sql.Tx
}
// SetupTest initializes ctx and db for each test method.
// SetupTest initializes ctx and client for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.client, s.tx = testEntSQLTx(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().

View File

@@ -0,0 +1,90 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
)
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set.
var applied int
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM schema_migrations").Scan(&applied))
require.GreaterOrEqual(t, applied, 7, "expected schema_migrations to contain applied migrations")
// users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)
// api_keys: key length should be 128
requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)
// redeem_codes: subscription fields
requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
// user_allowed_groups table should exist
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
var row struct {
DataType string
MaxLen sql.NullInt64
Nullable string
}
err := tx.QueryRowContext(context.Background(), `
SELECT
data_type,
character_maximum_length,
is_nullable
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`, table, column).Scan(&row.DataType, &row.MaxLen, &row.Nullable)
require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)
require.Equal(t, dataType, row.DataType, "data_type mismatch for %s.%s", table, column)
if maxLen > 0 {
require.True(t, row.MaxLen.Valid, "expected maxLen for %s.%s", table, column)
require.Equal(t, int64(maxLen), row.MaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
}
if nullable {
require.Equal(t, "YES", row.Nullable, "nullable mismatch for %s.%s", table, column)
} else {
require.Equal(t, "NO", row.Nullable, "nullable mismatch for %s.%s", table, column)
}
}

View File

@@ -2,52 +2,97 @@ package repository
import (
"context"
"time"
"database/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type sqlQuerier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type proxyRepository struct {
db *gorm.DB
client *dbent.Client
sql sqlQuerier
}
func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db}
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
return newProxyRepositoryWithSQL(client, sqlDB)
}
func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
return &proxyRepository{client: client, sql: sqlq}
}
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
builder := r.client.Proxy.Create().
SetName(proxyIn.Name).
SetProtocol(proxyIn.Protocol).
SetHost(proxyIn.Host).
SetPort(proxyIn.Port).
SetStatus(proxyIn.Status)
if proxyIn.Username != "" {
builder.SetUsername(proxyIn.Username)
}
if proxyIn.Password != "" {
builder.SetPassword(proxyIn.Password)
}
created, err := builder.Save(ctx)
if err == nil {
applyProxyModelToService(proxy, m)
applyProxyEntityToService(proxyIn, created)
}
return err
}
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var m proxyModel
err := r.db.WithContext(ctx).First(&m, id).Error
m, err := r.client.Proxy.Get(ctx, id)
if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrProxyNotFound
}
return nil, err
}
return proxyModelToService(&m), nil
return proxyEntityToService(m), nil
}
func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
SetName(proxyIn.Name).
SetProtocol(proxyIn.Protocol).
SetHost(proxyIn.Host).
SetPort(proxyIn.Port).
SetStatus(proxyIn.Status)
if proxyIn.Username != "" {
builder.SetUsername(proxyIn.Username)
} else {
builder.ClearUsername()
}
if proxyIn.Password != "" {
builder.SetPassword(proxyIn.Password)
} else {
builder.ClearPassword()
}
updated, err := builder.Save(ctx)
if err == nil {
applyProxyModelToService(proxy, m)
applyProxyEntityToService(proxyIn, updated)
return nil
}
if dbent.IsNotFound(err) {
return service.ErrProxyNotFound
}
return err
}
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
_, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx)
return err
}
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
@@ -56,104 +101,111 @@ func (r *proxyRepository) List(ctx context.Context, params pagination.Pagination
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []proxyModel
var total int64
db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters
q := r.client.Proxy.Query()
if protocol != "" {
db = db.Where("protocol = ?", protocol)
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
db = db.Where("status = ?", status)
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("name ILIKE ?", searchPattern)
q = q.Where(proxy.NameContainsFold(search))
}
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&proxies).Error; err != nil {
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
}
return outProxies, paginationResultFromTotal(total, params), nil
return outProxies, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).
All(ctx)
if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
}
return outProxies, nil
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error
if err != nil {
return false, err
q := r.client.Proxy.Query().
Where(proxy.HostEQ(host), proxy.PortEQ(port))
if username == "" {
q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
} else {
q = q.Where(proxy.UsernameEQ(username))
}
return count > 0, nil
if password == "" {
q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
} else {
q = q.Where(proxy.PasswordEQ(password))
}
count, err := q.Count(ctx)
return count > 0, err
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
row := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", proxyID)
var count int64
err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID).
Count(&count).Error
return count, err
if err := row.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct {
ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"`
}
var results []result
err := r.db.WithContext(ctx).
Table("accounts").
Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL").
Group("proxy_id").
Scan(&results).Error
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}
defer rows.Close()
counts := make(map[int64]int64)
for _, r := range results {
counts[r.ProxyID] = r.Count
for rows.Next() {
var proxyID, count int64
if err := rows.Scan(&proxyID, &count); err != nil {
return nil, err
}
counts[proxyID] = count
}
if err := rows.Err(); err != nil {
return nil, err
}
return counts, nil
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).
Where("status = ?", service.StatusActive).
Order("created_at DESC").
Find(&proxies).Error
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).
Order(dbent.Desc(proxy.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
@@ -167,76 +219,47 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]ser
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxy := proxyModelToService(&proxies[i])
if proxy == nil {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, nil
}
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
out := &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
if m.Username != nil {
out.Username = *m.Username
}
if m.Password != nil {
out.Password = *m.Password
}
return out
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
if dst == nil || src == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -4,26 +4,27 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *proxyRepository
ctx context.Context
sqlTx *sql.Tx
repo *proxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db).(*proxyRepository)
entClient, sqlTx := testEntSQLTx(s.T())
s.sqlTx = sqlTx
s.repo = newProxyRepositoryWithSQL(entClient, sqlTx)
}
func TestProxyRepoSuite(t *testing.T) {
@@ -56,7 +57,14 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy := &service.Proxy{
Name: "original",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, proxy))
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
@@ -68,7 +76,14 @@ func (s *ProxyRepoSuite) TestUpdate() {
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
proxy := &service.Proxy{
Name: "to-delete",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, proxy))
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
@@ -80,8 +95,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -90,8 +105,8 @@ func (s *ProxyRepoSuite) TestList() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "socks5", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
@@ -100,8 +115,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err)
@@ -110,8 +125,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Status() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
s.mustCreateProxy(&service.Proxy{Name: "production-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "dev-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
@@ -122,8 +137,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
s.mustCreateProxy(&service.Proxy{Name: "active1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "inactive1", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -134,13 +149,14 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &proxyModel{
s.mustCreateProxy(&service.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
@@ -153,13 +169,14 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &proxyModel{
s.mustCreateProxy(&service.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
Status: service.StatusActive,
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
@@ -170,10 +187,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-count", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustInsertAccount("a1", &proxy.ID)
s.mustInsertAccount("a2", &proxy.ID)
s.mustInsertAccount("a3", nil) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
@@ -181,7 +198,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-zero", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
@@ -191,12 +208,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
@@ -215,24 +232,13 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Status: service.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive",
Status: service.StatusDisabled,
})
p1 := s.mustCreateProxyWithTimes("p1", service.StatusActive, base.Add(-1*time.Hour))
p2 := s.mustCreateProxyWithTimes("p2", service.StatusActive, base)
s.mustCreateProxyWithTimes("p3-inactive", service.StatusDisabled, base.Add(1*time.Hour))
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
@@ -248,34 +254,16 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "u",
Password: "p",
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "1.2.3.4", Port: 8080, Username: "u", Password: "p", Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "5.6.7.8", Port: 8081, Username: "", Password: "", Status: service.StatusActive})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
@@ -300,3 +288,42 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
}
}
}
func (s *ProxyRepoSuite) mustCreateProxy(p *service.Proxy) *service.Proxy {
s.T().Helper()
s.Require().NoError(s.repo.Create(s.ctx, p), "create proxy")
return p
}
func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt time.Time) *service.Proxy {
s.T().Helper()
// Use the repository create for standard fields, then update timestamps via raw SQL to keep deterministic ordering.
p := s.mustCreateProxy(&service.Proxy{
Name: name,
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: status,
})
_, err := s.sqlTx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
s.Require().NoError(err, "update proxy timestamps")
return p
}
func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
s.T().Helper()
var pid any
if proxyID != nil {
pid = *proxyID
}
_, err := s.sqlTx.ExecContext(
s.ctx,
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
name,
service.PlatformAnthropic,
service.AccountTypeOAuth,
pid,
)
s.Require().NoError(err, "insert account")
}

View File

@@ -4,25 +4,35 @@ import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
type redeemCodeRepository struct {
db *gorm.DB
client *dbent.Client
}
func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db}
func NewRedeemCodeRepository(client *dbent.Client) service.RedeemCodeRepository {
return &redeemCodeRepository{client: client}
}
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
created, err := r.client.RedeemCode.Create().
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays).
SetNillableUsedBy(code.UsedBy).
SetNillableUsedAt(code.UsedAt).
SetNillableGroupID(code.GroupID).
Save(ctx)
if err == nil {
applyRedeemCodeModelToService(code, m)
code.ID = created.ID
code.CreatedAt = created.CreatedAt
}
return err
}
@@ -31,36 +41,55 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.
if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
builders := make([]*dbent.RedeemCodeCreate, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
c := &codes[i]
b := r.client.RedeemCode.Create().
SetCode(c.Code).
SetType(c.Type).
SetValue(c.Value).
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays).
SetNillableUsedBy(c.UsedBy).
SetNillableUsedAt(c.UsedAt).
SetNillableGroupID(c.GroupID)
builders = append(builders, b)
}
return r.db.WithContext(ctx).Create(&models).Error
return r.client.RedeemCode.CreateBulk(builders...).Exec(ctx)
}
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).First(&m, id).Error
m, err := r.client.RedeemCode.Query().
Where(redeemcode.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrRedeemCodeNotFound
}
return nil, err
}
return redeemCodeModelToService(&m), nil
return redeemCodeEntityToService(m), nil
}
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
m, err := r.client.RedeemCode.Query().
Where(redeemcode.CodeEQ(code)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
if dbent.IsNotFound(err) {
return nil, service.ErrRedeemCodeNotFound
}
return nil, err
}
return redeemCodeModelToService(&m), nil
return redeemCodeEntityToService(m), nil
}
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
_, err := r.client.RedeemCode.Delete().Where(redeemcode.IDEQ(id)).Exec(ctx)
return err
}
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
@@ -68,61 +97,88 @@ func (r *redeemCodeRepository) List(ctx context.Context, params pagination.Pagin
}
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
var codes []redeemCodeModel
var total int64
db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
q := r.client.RedeemCode.Query()
if codeType != "" {
db = db.Where("type = ?", codeType)
q = q.Where(redeemcode.TypeEQ(codeType))
}
if status != "" {
db = db.Where("status = ?", status)
q = q.Where(redeemcode.StatusEQ(status))
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("code ILIKE ?", searchPattern)
q = q.Where(redeemcode.CodeContainsFold(search))
}
if err := db.Count(&total).Error; err != nil {
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
if err := db.Preload("User").Preload("Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&codes).Error; err != nil {
codes, err := q.
WithUser().
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(redeemcode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
outCodes := redeemCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(total, params), nil
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
if code.UsedBy != nil {
up.SetUsedBy(*code.UsedBy)
} else {
up.ClearUsedBy()
}
return err
if code.UsedAt != nil {
up.SetUsedAt(*code.UsedAt)
} else {
up.ClearUsedAt()
}
if code.GroupID != nil {
up.SetGroupID(*code.GroupID)
} else {
up.ClearGroupID()
}
updated, err := up.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrRedeemCodeNotFound
}
return err
}
code.CreatedAt = updated.CreatedAt
return nil
}
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{
"status": service.StatusUsed,
"used_by": userID,
"used_at": now,
})
if result.Error != nil {
return result.Error
affected, err := r.client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).
SetUsedAt(now).
Save(ctx)
if err != nil {
return err
}
if result.RowsAffected == 0 {
return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
if affected == 0 {
return service.ErrRedeemCodeUsed
}
return nil
}
@@ -132,49 +188,24 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
limit = 10
}
var codes []redeemCodeModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("used_by = ?", userID).
Order("used_at DESC").
codes, err := r.client.RedeemCode.Query().
Where(redeemcode.UsedByEQ(userID)).
WithGroup().
Order(dbent.Desc(redeemcode.FieldUsedAt)).
Limit(limit).
Find(&codes).Error
All(ctx)
if err != nil {
return nil, err
}
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
return redeemCodeEntitiesToService(codes), nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
out := &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
@@ -182,38 +213,26 @@ func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
func redeemCodeEntitiesToService(models []*dbent.RedeemCode) []service.RedeemCode {
out := make([]service.RedeemCode, 0, len(models))
for i := range models {
if s := redeemCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
return out
}

View File

@@ -7,29 +7,47 @@ import (
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *redeemCodeRepository
ctx context.Context
client *dbent.Client
repo *redeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
entClient, _ := testEntSQLTx(s.T())
s.client = entClient
s.repo = NewRedeemCodeRepository(entClient).(*redeemCodeRepository)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
suite.Run(t, new(RedeemCodeRepoSuite))
}
func (s *RedeemCodeRepoSuite) createUser(email string) *dbent.User {
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
Save(s.ctx)
s.Require().NoError(err, "create user")
return u
}
func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
g, err := s.client.Group.Create().
SetName(name).
Save(s.ctx)
s.Require().NoError(err, "create group")
return g
}
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
@@ -70,10 +88,19 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch() {
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
_, err := s.client.RedeemCode.Create().
SetCode("GET-BY-CODE").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
Save(s.ctx)
s.Require().NoError(err, "seed redeem code")
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
@@ -83,25 +110,35 @@ func (s *RedeemCodeRepoSuite) TestGetByCode() {
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
s.Require().Error(err, "expected error for non-existent code")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
created, err := s.client.RedeemCode.Create().
SetCode("TO-DELETE").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
Save(s.ctx)
s.Require().NoError(err)
err := s.repo.Delete(s.ctx, code.ID)
err = s.repo.Delete(s.ctx, created.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, code.ID)
_, err = s.repo.GetByID(s.ctx, created.ID)
s.Require().Error(err, "expected error after delete")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-1", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-2", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -110,8 +147,8 @@ func (s *RedeemCodeRepoSuite) TestList() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-BAL", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
@@ -120,8 +157,8 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err)
@@ -130,8 +167,8 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
@@ -140,12 +177,17 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP",
Type: service.RedeemTypeSubscription,
GroupID: &group.ID,
})
group := s.createGroup(uniqueTestValue(s.T(), "g-preload"))
_, err := s.client.RedeemCode.Create().
SetCode("WITH-GROUP").
SetType(service.RedeemTypeSubscription).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetGroupID(group.ID).
Save(s.ctx)
s.Require().NoError(err)
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
s.Require().NoError(err)
@@ -157,7 +199,13 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code := &service.RedeemCode{
Code: "UPDATE-ME",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
}
s.Require().NoError(s.repo.Create(s.ctx, code))
code.Value = 50
err := s.repo.Update(s.ctx, code)
@@ -171,8 +219,9 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
user := s.createUser(uniqueTestValue(s.T(), "use") + "@example.com")
code := &service.RedeemCode{Code: "USE-ME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
@@ -186,8 +235,9 @@ func (s *RedeemCodeRepoSuite) TestUse() {
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
user := s.createUser(uniqueTestValue(s.T(), "idem") + "@example.com")
code := &service.RedeemCode{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
@@ -199,8 +249,9 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
user := s.createUser(uniqueTestValue(s.T(), "already") + "@example.com")
code := &service.RedeemCode{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
@@ -210,25 +261,34 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
user := s.createUser(uniqueTestValue(s.T(), "listby") + "@example.com")
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1",
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
usedAt1 := base
_, err := s.client.RedeemCode.Create().
SetCode("USER-1").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(usedAt1).
Save(s.ctx)
s.Require().NoError(err)
c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2",
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
usedAt2 := base.Add(1 * time.Hour)
_, err = s.client.RedeemCode.Create().
SetCode("USER-2").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(usedAt2).
Save(s.ctx)
s.Require().NoError(err)
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
@@ -239,17 +299,21 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
user := s.createUser(uniqueTestValue(s.T(), "grp") + "@example.com")
group := s.createGroup(uniqueTestValue(s.T(), "g-listby"))
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP",
Type: service.RedeemTypeSubscription,
Status: service.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
s.db.Model(c).Update("used_at", time.Now())
_, err := s.client.RedeemCode.Create().
SetCode("WITH-GRP").
SetType(service.RedeemTypeSubscription).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(time.Now()).
SetGroupID(group.ID).
Save(s.ctx)
s.Require().NoError(err)
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err)
@@ -259,14 +323,18 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM",
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
user := s.createUser(uniqueTestValue(s.T(), "deflimit") + "@example.com")
_, err := s.client.RedeemCode.Create().
SetCode("DEF-LIM").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(time.Now()).
Save(s.ctx)
s.Require().NoError(err)
// limit <= 0 should default to 10
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
@@ -277,12 +345,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
user := s.createUser(uniqueTestValue(s.T(), "rc") + "@example.com")
group := s.createGroup(uniqueTestValue(s.T(), "g-rc"))
groupID := group.ID
codes := []service.RedeemCode{
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, Notes: ""},
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, Notes: "", GroupID: &groupID, ValidityDays: 7},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
@@ -303,10 +372,16 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
// Use fixed time instead of time.Sleep for deterministic ordering.
_, err = s.client.RedeemCode.UpdateOneID(codeB.ID).
SetUsedAt(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)).
Save(s.ctx)
s.Require().NoError(err)
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
_, err = s.client.RedeemCode.UpdateOneID(codeA.ID).
SetUsedAt(time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)).
Save(s.ctx)
s.Require().NoError(err)
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")

View File

@@ -4,27 +4,33 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type settingRepository struct {
db *gorm.DB
client *ent.Client
}
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db}
func NewSettingRepository(client *ent.Client) service.SettingRepository {
return &settingRepository{client: client}
}
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
var m settingModel
err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
m, err := r.client.Setting.Query().Where(setting.KeyEQ(key)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
if ent.IsNotFound(err) {
return nil, service.ErrSettingNotFound
}
return nil, err
}
return settingModelToService(&m), nil
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}, nil
}
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
@@ -36,21 +42,22 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
}
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
}
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(m).Error
now := time.Now()
return r.client.Setting.
Create().
SetKey(key).
SetValue(value).
SetUpdatedAt(now).
OnConflictColumns(setting.FieldKey).
UpdateNewValues().
Exec(ctx)
}
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if len(keys) == 0 {
return map[string]string{}, nil
}
settings, err := r.client.Setting.Query().Where(setting.KeyIn(keys...)).All(ctx)
if err != nil {
return nil, err
}
@@ -63,27 +70,24 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
}
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings {
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
}
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(m).Error; err != nil {
return err
}
}
if len(settings) == 0 {
return nil
})
}
now := time.Now()
builders := make([]*ent.SettingCreate, 0, len(settings))
for key, value := range settings {
builders = append(builders, r.client.Setting.Create().SetKey(key).SetValue(value).SetUpdatedAt(now))
}
return r.client.Setting.
CreateBulk(builders...).
OnConflictColumns(setting.FieldKey).
UpdateNewValues().
Exec(ctx)
}
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error
settings, err := r.client.Setting.Query().All(ctx)
if err != nil {
return nil, err
}
@@ -96,26 +100,6 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
}
func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
_, err := r.client.Setting.Delete().Where(setting.KeyEQ(key)).Exec(ctx)
return err
}

View File

@@ -8,20 +8,18 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type SettingRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *settingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db).(*settingRepository)
entClient, _ := testEntSQLTx(s.T())
s.repo = NewSettingRepository(entClient).(*settingRepository)
}
func TestSettingRepoSuite(t *testing.T) {

View File

@@ -0,0 +1,110 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueSoftDeleteValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *dbent.User {
t.Helper()
u, err := client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
Save(ctx)
require.NoError(t, err, "create ent user")
return u
}
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "first delete")
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
client, _ := testEntSQLTx(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,35 +4,39 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *usageLogRepository
ctx context.Context
tx *sql.Tx
client *dbent.Client
repo *usageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
client, tx := testEntSQLTx(s.T())
s.client = client
s.tx = tx
s.repo = newUsageLogRepositoryWithSQL(client, tx)
}
func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
@@ -51,9 +55,9 @@ func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel,
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
@@ -72,9 +76,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -92,9 +96,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -108,9 +112,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -124,9 +128,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -140,9 +144,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -155,9 +159,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -175,9 +179,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -194,26 +198,26 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &userModel{
userToday := mustCreateUser(s.T(), s.client, &service.User{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &userModel{
userOld := mustCreateUser(s.T(), s.client, &service.User{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
@@ -285,7 +289,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().NoError(err, "getPerformanceStats")
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
}
@@ -293,9 +298,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -308,9 +313,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -323,11 +328,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -348,10 +353,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -370,9 +375,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -395,9 +400,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -414,9 +419,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -433,9 +438,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -452,9 +457,9 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -508,9 +513,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
@@ -528,9 +533,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -545,9 +550,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -564,9 +569,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -611,9 +616,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -639,9 +644,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -658,9 +663,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -712,9 +717,9 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -758,7 +763,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
@@ -774,11 +779,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
@@ -796,10 +801,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
@@ -815,9 +820,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
@@ -834,9 +839,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -848,9 +853,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -867,9 +872,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)

View File

@@ -2,252 +2,412 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"database/sql"
"sort"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"gorm.io/gorm"
)
type userRepository struct {
db *gorm.DB
client *dbent.Client
sql sqlExecutor
begin sqlBeginner
}
func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
return &userRepository{client: client, sql: sqlq, begin: beginner}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
exec := r.sql
txClient := r.client
var sqlTx *sql.Tx
var txClientClose func() error
if r.begin != nil {
var err error
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
txClientClose = txClient.Close
defer func() { _ = sqlTx.Rollback() }()
}
if txClientClose != nil {
defer func() { _ = txClientClose() }()
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
return err
}
}
applyUserEntityToService(userIn, created)
return nil
}
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).First(&m, id).Error
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return userModelToService(&m), nil
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err == nil {
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
}
return out, nil
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return userModelToService(&m), nil
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
}
return out, nil
}
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
exec := r.sql
txClient := r.client
var sqlTx *sql.Tx
var txClientClose func() error
if r.begin != nil {
var err error
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
txClientClose = txClient.Close
defer func() { _ = sqlTx.Rollback() }()
}
if txClientClose != nil {
defer func() { _ = txClientClose() }()
}
updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
return err
}
}
userIn.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
_, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
return err
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []userModel
var total int64
q := r.client.User.Query()
db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters
if status != "" {
db = db.Where("status = ?", status)
q = q.Where(dbuser.StatusEQ(status))
}
if role != "" {
db = db.Where("role = ?", role)
q = q.Where(dbuser.RoleEQ(role))
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where(
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
searchPattern, searchPattern, searchPattern,
q = q.Where(
dbuser.Or(
dbuser.EmailContainsFold(search),
dbuser.UsernameContainsFold(search),
dbuser.WechatContainsFold(search),
),
)
}
if err := db.Count(&total).Error; err != nil {
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
// Query users with pagination (reuse the same db with filters applied)
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
users, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 {
userIDs := make([]int64, len(users))
userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users {
userIDs[i] = users[i].ID
u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Query active subscriptions with groups in one query
var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil {
return nil, nil, err
}
// Associate subscriptions with users
for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
}
}
return outUsers, paginationResultFromTotal(total, params), nil
}
outUsers := make([]service.User, 0, len(users))
for i := range users {
outUsers = append(outUsers, *userModelToService(&users[i]))
if len(users) == 0 {
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
return outUsers, paginationResultFromTotal(total, params), nil
userIDs := make([]int64, 0, len(users))
userMap := make(map[int64]*service.User, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
u := userEntityToService(users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err == nil {
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
}
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
return err
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil {
return result.Error
n, err := r.client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
AddBalance(-amount).
Save(ctx)
if err != nil {
return err
}
if result.RowsAffected == 0 {
if n == 0 {
return service.ErrInsufficientBalance
}
return nil
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
return err
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error
if r.sql == nil {
return 0, nil
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx)
if err != nil {
return 0, err
}
arrayRes, err := r.sql.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
)
if err != nil {
return 0, err
}
arrayAffected, _ := arrayRes.RowsAffected()
if int64(joinAffected) > arrayAffected {
return int64(joinAffected), nil
}
return arrayAffected, nil
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC").
First(&m).Error
m, err := r.client.User.Query().
Where(
dbuser.RoleEQ(service.RoleAdmin),
dbuser.StatusEQ(service.StatusActive),
).
Order(dbent.Asc(dbuser.FieldID)).
First(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return userModelToService(&m), nil
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
}
return out, nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
out := make(map[int64][]int64, len(userIDs))
if len(userIDs) == 0 {
return out, nil
}
rows, err := r.client.UserAllowedGroup.Query().
Where(userallowedgroup.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
for i := range rows {
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
}
for userID := range out {
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
}
return out, nil
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
if client == nil || exec == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
// Keep join table as the source of truth for reads.
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
return err
}
unique := make(map[int64]struct{}, len(groupIDs))
for _, id := range groupIDs {
if id <= 0 {
continue
}
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
return err
}
}
// Phase 1 compatibility: keep legacy users.allowed_groups array updated for existing raw SQL paths.
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := exec.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
func applyUserEntityToService(dst *service.User, src *dbent.User) {
if dst == nil || src == nil {
return
}

View File

@@ -4,46 +4,103 @@ package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *userRepository
ctx context.Context
tx *sql.Tx
client *dbent.Client
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db).(*userRepository)
entClient, tx := testEntSQLTx(s.T())
s.tx = tx
s.client = entClient
s.repo = newUserRepositoryWithSQL(entClient, tx)
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
s.T().Helper()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
return u
}
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
s.T().Helper()
now := time.Now()
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1*time.Hour)).
SetExpiresAt(now.Add(24*time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
if mutate != nil {
mutate(create)
}
sub, err := create.Save(s.ctx)
s.Require().NoError(err, "create subscription")
return sub
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &service.User{
user := s.mustCreateUser(&service.User{
Email: "create@test.com",
Username: "testuser",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
})
err := s.repo.Create(s.ctx, user)
s.Require().NoError(err, "Create")
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
@@ -57,7 +114,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
@@ -70,19 +127,20 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
}
func (s *UserRepoSuite) TestUpdate() {
user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
s.Require().NoError(err, "Update")
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
got.Username = "updated"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
updated, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Username)
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
@@ -94,8 +152,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
s.mustCreateUser(&service.User{Email: "list1@test.com"})
s.mustCreateUser(&service.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -104,8 +162,8 @@ func (s *UserRepoSuite) TestList() {
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err)
@@ -114,8 +172,8 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err)
@@ -124,8 +182,8 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
@@ -134,8 +192,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
@@ -144,8 +202,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
@@ -154,20 +212,17 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active")
groupExpired := s.mustCreateGroup("g-sub-expired")
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusActive)
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
})
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
@@ -175,11 +230,11 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
@@ -187,7 +242,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &userModel{
target := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
@@ -195,7 +250,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
@@ -211,40 +266,40 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(12.5, got.Balance)
s.Require().InDelta(12.5, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(7.0, got.Balance)
s.Require().InDelta(7.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(5.0, got.Balance)
s.Require().InDelta(5.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
@@ -252,20 +307,20 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Zero(got.Balance)
s.Require().InDelta(0.0, got.Balance, 1e-6)
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
@@ -276,7 +331,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
@@ -289,7 +344,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
s.mustCreateUser(&service.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
@@ -303,34 +358,38 @@ func (s *UserRepoSuite) TestExistsByEmail() {
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &userModel{
target := s.mustCreateGroup("target-42")
other := s.mustCreateGroup("other-7")
userA := s.mustCreateUser(&service.User{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
AllowedGroups: []int64{target.ID, other.ID},
})
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
AllowedGroups: []int64{other.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
for _, id := range got.AllowedGroups {
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
}
s.Require().NotContains(got.AllowedGroups, target.ID)
s.Require().Contains(got.AllowedGroups, other.ID)
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &userModel{
groupA := s.mustCreateGroup("nomatch-a")
groupB := s.mustCreateGroup("nomatch-b")
s.mustCreateUser(&service.User{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
AllowedGroups: []int64{groupA.ID, groupB.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
@@ -338,12 +397,12 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &userModel{
admin1 := s.mustCreateUser(&service.User{
Email: "admin1@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "admin2@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
@@ -355,7 +414,7 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
@@ -366,12 +425,12 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "disabled@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
activeAdmin := s.mustCreateUser(&service.User{
Email: "active@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
@@ -382,10 +441,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined original test ---
// --- Combined ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &userModel{
user1 := s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
@@ -393,7 +452,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
Status: service.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &userModel{
user2 := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
@@ -401,7 +460,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
Status: service.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &userModel{
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
@@ -424,12 +483,12 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
s.Require().InDelta(12.5, got3.Balance, 1e-6)
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
s.Require().InDelta(7.5, got4.Balance, 1e-6)
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
@@ -438,7 +497,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
@@ -447,3 +506,4 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}

View File

@@ -4,333 +4,336 @@ import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
type userSubscriptionRepository struct {
db *gorm.DB
client *dbent.Client
}
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptionRepository {
return &userSubscriptionRepository{client: client}
}
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(m).Error
if sub == nil {
return nil
}
builder := r.client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
SetNillableDailyWindowStart(sub.DailyWindowStart).
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
SetDailyUsageUsd(sub.DailyUsageUSD).
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
SetNillableAssignedBy(sub.AssignedBy)
if sub.StartsAt.IsZero() {
builder.SetStartsAt(time.Now())
} else {
builder.SetStartsAt(sub.StartsAt)
}
if sub.Status != "" {
builder.SetStatus(sub.Status)
}
if !sub.AssignedAt.IsZero() {
builder.SetAssignedAt(sub.AssignedAt)
}
// Keep compatibility with historical behavior: always store notes as a string value.
builder.SetNotes(sub.Notes)
created, err := builder.Save(ctx)
if err == nil {
applyUserSubscriptionModelToService(sub, m)
applyUserSubscriptionEntityToService(sub, created)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("User").
Preload("Group").
Preload("AssignedByUser").
First(&m, id).Error
m, err := r.client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
WithAssignedByUser().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&m).Error
m, err := r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&m).Error
m, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
WithGroup().
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return userSubscriptionModelToService(&m), nil
return userSubscriptionEntityToService(m), nil
}
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
sub.UpdatedAt = time.Now()
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
if sub == nil {
return nil
}
return err
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
SetExpiresAt(sub.ExpiresAt).
SetStatus(sub.Status).
SetNillableDailyWindowStart(sub.DailyWindowStart).
SetNillableWeeklyWindowStart(sub.WeeklyWindowStart).
SetNillableMonthlyWindowStart(sub.MonthlyWindowStart).
SetDailyUsageUsd(sub.DailyUsageUSD).
SetWeeklyUsageUsd(sub.WeeklyUsageUSD).
SetMonthlyUsageUsd(sub.MonthlyUsageUSD).
SetNillableAssignedBy(sub.AssignedBy).
SetAssignedAt(sub.AssignedAt).
SetNotes(sub.Notes)
updated, err := builder.Save(ctx)
if err == nil {
applyUserSubscriptionEntityToService(sub, updated)
return nil
}
return translatePersistenceError(err, service.ErrSubscriptionNotFound, service.ErrSubscriptionAlreadyExists)
}
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?",
userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC").
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
subs, err := q.
WithUser().
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
q := r.client.UserSubscription.Query()
if userID != nil {
query = query.Where("user_id = ?", *userID)
q = q.Where(usersubscription.UserIDEQ(*userID))
}
if groupID != nil {
query = query.Where("group_id = ?", *groupID)
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if status != "" {
query = query.Where("status = ?", status)
q = q.Where(usersubscription.StatusEQ(status))
}
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
err := query.
Preload("User").
Preload("Group").
Preload("AssignedByUser").
Order("created_at DESC").
Offset(params.Offset()).
Limit(params.Limit()).
Find(&subs).Error
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
subs, err := q.
WithUser().
WithGroup().
WithAssignedByUser().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
if err != nil {
return nil, nil, err
}
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
return r.client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": start,
"weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
"updated_at": time.Now(),
}).Error
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
"status": service.SubscriptionStatusExpired,
"updated_at": time.Now(),
})
return result.RowsAffected, result.Error
n, err := r.client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
).
SetStatus(service.SubscriptionStatusExpired).
Save(ctx)
return int64(n), err
}
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
).
All(ctx)
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
return userSubscriptionEntitiesToService(subs), nil
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID).
Count(&count).Error
return count, err
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error
return count, err
count, err := r.client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtGT(time.Now()),
).
Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
func userSubscriptionEntityToService(m *dbent.UserSubscription) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
out := &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
@@ -340,60 +343,42 @@ func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubsc
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
DailyUsageUSD: m.DailyUsageUsd,
WeeklyUsageUSD: m.WeeklyUsageUsd,
MonthlyUsageUSD: m.MonthlyUsageUsd,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
if m.Edges.AssignedByUser != nil {
out.AssignedByUser = userEntityToService(m.Edges.AssignedByUser)
}
return out
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
func userSubscriptionEntitiesToService(models []*dbent.UserSubscription) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
if s := userSubscriptionEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
func applyUserSubscriptionEntityToService(dst *service.UserSubscription, src *dbent.UserSubscription) {
if dst == nil || src == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -7,34 +7,85 @@ import (
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *userSubscriptionRepository
ctx context.Context
client *dbent.Client
repo *userSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
client, _ := testEntSQLTx(s.T())
s.client = client
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
suite.Run(t, new(UserSubscriptionRepoSuite))
}
func (s *UserSubscriptionRepoSuite) mustCreateUser(email string, role string) *service.User {
s.T().Helper()
if role == "" {
role = service.RoleUser
}
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(role).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *UserSubscriptionRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
s.T().Helper()
now := time.Now()
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1*time.Hour)).
SetExpiresAt(now.Add(24*time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
if mutate != nil {
mutate(create)
}
sub, err := create.Save(s.ctx)
s.Require().NoError(err, "create user subscription")
return sub
}
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
user := s.mustCreateUser("sub-create@test.com", service.RoleUser)
group := s.mustCreateGroup("g-create")
sub := &service.UserSubscription{
UserID: user.ID,
@@ -54,16 +105,12 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
user := s.mustCreateUser("preload@test.com", service.RoleUser)
group := s.mustCreateGroup("g-preload")
admin := s.mustCreateUser("admin@test.com", service.RoleAdmin)
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetAssignedBy(admin.ID)
})
got, err := s.repo.GetByID(s.ctx, sub.ID)
@@ -82,18 +129,15 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}))
user := s.mustCreateUser("update@test.com", service.RoleUser)
group := s.mustCreateGroup("g-update")
created := s.mustCreateSubscription(user.ID, group.ID, nil)
sub, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err, "GetByID")
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
s.Require().NoError(err, "Update")
s.Require().NoError(s.repo.Update(s.ctx, sub), "Update")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID after update")
@@ -101,14 +145,9 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("delete@test.com", service.RoleUser)
group := s.mustCreateGroup("g-delete")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.Delete(s.ctx, sub.ID)
s.Require().NoError(err, "Delete")
@@ -117,17 +156,16 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
func (s *UserSubscriptionRepoSuite) TestDelete_Idempotent() {
s.Require().NoError(s.repo.Delete(s.ctx, 42424242), "Delete should be idempotent")
}
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("byuser@test.com", service.RoleUser)
group := s.mustCreateGroup("g-byuser")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetByUserIDAndGroupID")
@@ -141,15 +179,11 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
user := s.mustCreateUser("active@test.com", service.RoleUser)
group := s.mustCreateGroup("g-active")
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
active := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
@@ -158,15 +192,11 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
user := s.mustCreateUser("expired@test.com", service.RoleUser)
group := s.mustCreateGroup("g-expired")
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
})
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
@@ -176,21 +206,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
user := s.mustCreateUser("listby@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-list1")
g2 := s.mustCreateGroup("g-list2")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
@@ -202,21 +225,16 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
user := s.mustCreateUser("listactive@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-act1")
g2 := s.mustCreateGroup("g-act2")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
s.mustCreateSubscription(user.ID, g1.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
@@ -228,22 +246,12 @@ func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
user1 := s.mustCreateUser("u1@test.com", service.RoleUser)
user2 := s.mustCreateUser("u2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-listgrp")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -258,15 +266,9 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("list@test.com", service.RoleUser)
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
s.Require().NoError(err, "List")
@@ -275,22 +277,12 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
user1 := s.mustCreateUser("filter1@test.com", service.RoleUser)
user2 := s.mustCreateUser("filter2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-filter")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
s.Require().NoError(err)
@@ -299,22 +291,12 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
user := s.mustCreateUser("grpfilter@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-f1")
g2 := s.mustCreateGroup("g-f2")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
s.Require().NoError(err)
@@ -323,20 +305,18 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
user1 := s.mustCreateUser("statfilter1@test.com", service.RoleUser)
user2 := s.mustCreateUser("statfilter2@test.com", service.RoleUser)
group1 := s.mustCreateGroup("g-stat-1")
group2 := s.mustCreateGroup("g-stat-2")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
s.mustCreateSubscription(user1.ID, group1.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusActive)
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
s.mustCreateSubscription(user2.ID, group2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
@@ -348,52 +328,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("usage@test.com", service.RoleUser)
group := s.mustCreateGroup("g-usage")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
s.Require().NoError(err, "IncrementUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(1.25, got.DailyUsageUSD)
s.Require().Equal(1.25, got.WeeklyUsageUSD)
s.Require().Equal(1.25, got.MonthlyUsageUSD)
s.Require().InDelta(1.25, got.DailyUsageUSD, 1e-6)
s.Require().InDelta(1.25, got.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(1.25, got.MonthlyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("accum@test.com", service.RoleUser)
group := s.mustCreateGroup("g-accum")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(3.5, got.DailyUsageUSD)
s.Require().InDelta(3.5, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("activate@test.com", service.RoleUser)
group := s.mustCreateGroup("g-activate")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
@@ -404,19 +369,15 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
s.Require().NotNil(got.DailyWindowStart)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().True(got.DailyWindowStart.Equal(activateAt))
s.Require().WithinDuration(activateAt, *got.DailyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
user := s.mustCreateUser("resetd@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetd")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetDailyUsageUsd(10.0)
c.SetWeeklyUsageUsd(20.0)
})
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
@@ -425,21 +386,18 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.DailyUsageUSD)
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
s.Require().True(got.DailyWindowStart.Equal(resetAt))
s.Require().InDelta(0.0, got.DailyUsageUSD, 1e-6)
s.Require().InDelta(20.0, got.WeeklyUsageUSD, 1e-6)
s.Require().NotNil(got.DailyWindowStart)
s.Require().WithinDuration(resetAt, *got.DailyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
user := s.mustCreateUser("resetw@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetw")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetWeeklyUsageUsd(15.0)
c.SetMonthlyUsageUsd(30.0)
})
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
@@ -448,20 +406,17 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.WeeklyUsageUSD)
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
s.Require().InDelta(0.0, got.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(30.0, got.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().WithinDuration(resetAt, *got.WeeklyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
user := s.mustCreateUser("resetm@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetm")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetMonthlyUsageUsd(25.0)
})
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
@@ -470,21 +425,17 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.MonthlyUsageUSD)
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
s.Require().InDelta(0.0, got.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().WithinDuration(resetAt, *got.MonthlyWindowStart, time.Microsecond)
}
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("status@test.com", service.RoleUser)
group := s.mustCreateGroup("g-status")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
@@ -495,14 +446,9 @@ func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("extend@test.com", service.RoleUser)
group := s.mustCreateGroup("g-extend")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
@@ -510,18 +456,13 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().True(got.ExpiresAt.Equal(newExpiry))
s.Require().WithinDuration(newExpiry, got.ExpiresAt, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
user := s.mustCreateUser("notes@test.com", service.RoleUser)
group := s.mustCreateGroup("g-notes")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
s.Require().NoError(err, "UpdateNotes")
@@ -534,20 +475,15 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
user := s.mustCreateUser("listexp@test.com", service.RoleUser)
groupActive := s.mustCreateGroup("g-listexp-active")
groupExpired := s.mustCreateGroup("g-listexp-expired")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
expired, err := s.repo.ListExpired(s.ctx)
@@ -556,20 +492,15 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
user := s.mustCreateUser("batch@test.com", service.RoleUser)
groupFuture := s.mustCreateGroup("g-batch-future")
groupPast := s.mustCreateGroup("g-batch-past")
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
active := s.mustCreateSubscription(user.ID, groupFuture.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
expiredActive := s.mustCreateSubscription(user.ID, groupPast.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
@@ -586,15 +517,10 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
user := s.mustCreateUser("exists@test.com", service.RoleUser)
group := s.mustCreateGroup("g-exists")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.mustCreateSubscription(user.ID, group.ID, nil)
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
@@ -608,21 +534,14 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
user1 := s.mustCreateUser("cnt1@test.com", service.RoleUser)
user2 := s.mustCreateUser("cnt2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-count")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
@@ -631,21 +550,15 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
user1 := s.mustCreateUser("cntact1@test.com", service.RoleUser)
user2 := s.mustCreateUser("cntact2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-cntact")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
s.mustCreateSubscription(user1.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) // expired by time
})
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
@@ -656,21 +569,12 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
user1 := s.mustCreateUser("delgrp1@test.com", service.RoleUser)
user2 := s.mustCreateUser("delgrp2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-delgrp")
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "DeleteByGroupID")
@@ -680,26 +584,21 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
s.Require().Zero(count)
}
// --- Combined original test ---
// --- Combined scenario ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
user := s.mustCreateUser("subr@example.com", service.RoleUser)
groupActive := s.mustCreateGroup("g-subr-active")
groupExpired := s.mustCreateGroup("g-subr-expired")
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
active := s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
})
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
expiredActive := s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, groupActive.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID, "expected active subscription")
@@ -709,9 +608,9 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
after, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
s.Require().InDelta(1.25, after.DailyUsageUSD, 1e-6)
s.Require().InDelta(1.25, after.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(1.25, after.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
@@ -720,14 +619,16 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID after reset")
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
s.Require().InDelta(0.0, afterReset.DailyUsageUSD, 1e-6)
s.Require().NotNil(afterReset.DailyWindowStart)
s.Require().WithinDuration(resetAt, *afterReset.DailyWindowStart, time.Microsecond)
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}

View File

@@ -0,0 +1,214 @@
package middleware
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") }
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { return errors.New("not implemented") }
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") }
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { return 0, errors.New("not implemented") }
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { return false, errors.New("not implemented") }
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // cache
&config.Config{},
)
}
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is required", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer invalid")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "Invalid API key", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer any")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
require.Equal(t, "Failed to validate API key", resp.Error.Message)
require.Equal(t, "INTERNAL", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer disabled")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is disabled", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
Balance: 0,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer ok")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusForbidden, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusForbidden, resp.Error.Code)
require.Equal(t, "Insufficient account balance", resp.Error.Message)
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
}

View File

@@ -3,6 +3,7 @@ package setup
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"log"
@@ -10,13 +11,12 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/service"
_ "github.com/lib/pq"
"github.com/redis/go-redis/v9"
"gopkg.in/yaml.v3"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// Config paths
@@ -92,20 +92,16 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.SSLMode,
)
db, err := gorm.Open(postgres.Open(defaultDSN), &gorm.Config{})
db, err := sql.Open("postgres", defaultDSN)
if err != nil {
return fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return fmt.Errorf("failed to get db instance: %w", err)
}
defer func() {
if sqlDB == nil {
if db == nil {
return
}
if err := sqlDB.Close(); err != nil {
if err := db.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
@@ -113,22 +109,23 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
if err := db.PingContext(ctx); err != nil {
return fmt.Errorf("ping failed: %w", err)
}
// Check if target database exists
var exists bool
row := sqlDB.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", cfg.DBName)
row := db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)", cfg.DBName)
if err := row.Scan(&exists); err != nil {
return fmt.Errorf("failed to check database existence: %w", err)
}
// Create database if not exists
if !exists {
// 注意:数据库名不能参数化,依赖前置输入校验保障安全。
// Note: Database names cannot be parameterized, but we've already validated cfg.DBName
// in the handler using validateDBName() which only allows [a-zA-Z][a-zA-Z0-9_]*
_, err := sqlDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName))
_, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName))
if err != nil {
return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err)
}
@@ -136,27 +133,23 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
}
// Now connect to the target database to verify
if err := sqlDB.Close(); err != nil {
if err := db.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
sqlDB = nil
db = nil
targetDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode,
)
targetDB, err := gorm.Open(postgres.Open(targetDSN), &gorm.Config{})
targetDB, err := sql.Open("postgres", targetDSN)
if err != nil {
return fmt.Errorf("failed to connect to database '%s': %w", cfg.DBName, err)
}
targetSqlDB, err := targetDB.DB()
if err != nil {
return fmt.Errorf("failed to get target db instance: %w", err)
}
defer func() {
if err := targetSqlDB.Close(); err != nil {
if err := targetDB.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
@@ -164,7 +157,7 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
if err := targetSqlDB.PingContext(ctx2); err != nil {
if err := targetDB.PingContext(ctx2); err != nil {
return fmt.Errorf("ping target database failed: %w", err)
}
@@ -256,22 +249,18 @@ func initializeDatabase(cfg *SetupConfig) error {
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
db, err := sql.Open("postgres", dsn)
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer func() {
if err := sqlDB.Close(); err != nil {
if err := db.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
return repository.AutoMigrate(db)
return infrastructure.ApplyMigrations(context.Background(), db)
}
func createAdminUser(cfg *SetupConfig) error {
@@ -281,24 +270,24 @@ func createAdminUser(cfg *SetupConfig) error {
cfg.Database.Password, cfg.Database.DBName, cfg.Database.SSLMode,
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
db, err := sql.Open("postgres", dsn)
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
defer func() {
if err := sqlDB.Close(); err != nil {
if err := db.Close(); err != nil {
log.Printf("failed to close postgres connection: %v", err)
}
}()
// 使用超时上下文避免安装流程因数据库异常而长时间阻塞。
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Check if admin already exists
var count int64
if err := db.Table("users").Where("role = ?", service.RoleAdmin).Count(&count).Error; err != nil {
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&count); err != nil {
return err
}
if count > 0 {
@@ -319,7 +308,20 @@ func createAdminUser(cfg *SetupConfig) error {
return err
}
return repository.NewUserRepository(db).Create(context.Background(), admin)
_, err = db.ExecContext(
ctx,
`INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
admin.Email,
admin.PasswordHash,
admin.Role,
admin.Balance,
admin.Concurrency,
admin.Status,
admin.CreatedAt,
admin.UpdatedAt,
)
return err
}
func writeConfigFile(cfg *SetupConfig) error {
@@ -339,7 +341,10 @@ func writeConfigFile(cfg *SetupConfig) error {
ExpireHour int `yaml:"expire_hour"`
} `yaml:"jwt"`
Default struct {
GroupID uint `yaml:"group_id"`
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
} `yaml:"default"`
RateLimit struct {
RequestsPerMinute int `yaml:"requests_per_minute"`
@@ -358,9 +363,15 @@ func writeConfigFile(cfg *SetupConfig) error {
ExpireHour: cfg.JWT.ExpireHour,
},
Default: struct {
GroupID uint `yaml:"group_id"`
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
}{
GroupID: 1,
UserConcurrency: 5,
UserBalance: 0,
ApiKeyPrefix: "sk-",
RateMultiplier: 1.0,
},
RateLimit: struct {
RequestsPerMinute int `yaml:"requests_per_minute"`