refactor: 调整项目结构为单向依赖

This commit is contained in:
Forest
2025-12-26 15:40:24 +08:00
parent b3463769dc
commit e5a77853b0
98 changed files with 5503 additions and 3352 deletions

View File

@@ -5,10 +5,10 @@ import (
"errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
}
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
// 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
return accountModelToService(&m), nil
}
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &account, nil
return accountModelToService(&m), nil
}
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
// 填充每个 Account 的虚拟字段GroupIDs 和 Groups
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outAccounts, paginationResultFromTotal(total, params), nil
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"status": model.StatusError,
"status": service.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
ag := &accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
Delete(&accountGroupModel{}).Error
}
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
}
if err != nil {
return nil, err
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return outGroups, nil
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
if len(groupIDs) == 0 {
return nil
}
return nil
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
// ListSchedulable 获取所有可调度的账号
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// SetRateLimited 标记账号为限流状态(429)
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
var account accountModel
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
account.Extra = datatypes.JSONMap{}
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
}
if len(updateMap) == 0 {
@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Model(&accountModel{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{
Name: "test-create",
Platform: model.PlatformAnthropic,
Type: model.AccountTypeOAuth,
Status: model.StatusActive,
account := &service.Account{
Name: "test-create",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
}
err := s.repo.Create(s.ctx, account)
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
}
func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
status string
search string
wantCount int
validate func(accounts []model.Account)
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: model.PlatformOpenAI,
platform: service.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
validate: func(accounts []service.Account) {
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: model.AccountTypeApiKey,
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
},
status: model.StatusDisabled,
status: service.StatusDisabled,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
validate: func(accounts []service.Account) {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []model.Account) {
validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1",
ProxyID: &proxy.ID,
})
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra",
Extra: model.JSONB{"a": "1"},
Extra: datatypes.JSONMap{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID},
Extra: datatypes.JSONMap{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"},
Credentials: datatypes.JSONMap{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"},
Credentials: datatypes.JSONMap{"new_key": "new_value"},
})
s.Require().NoError(err)
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"},
Extra: datatypes.JSONMap{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"},
Extra: datatypes.JSONMap{"new_key": "new_val"},
})
s.Require().NoError(err)
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []model.Account) []int64 {
func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
}
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &key, nil
return apiKeyModelToService(&m), nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &apiKey, nil
return apiKeyModelToService(&m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
return keys, nil
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &model.ApiKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: model.StatusActive,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, key)
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
Status: service.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: model.StatusActive,
})
Status: service.StatusActive,
}))
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status)
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
})
}))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
Status: service.StatusActive,
}))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.Status = service.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status)
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",

View File

@@ -0,0 +1,20 @@
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}

View File

@@ -6,21 +6,25 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper()
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = model.RoleUser
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = model.StatusActive
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper()
if g.Platform == "" {
g.Platform = model.PlatformAnthropic
g.Platform = service.PlatformAnthropic
}
if g.Status == "" {
g.Status = model.StatusActive
g.Status = service.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard
g.SubscriptionType = service.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper()
if p.Protocol == "" {
p.Protocol = "http"
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p.Port = 8080
}
if p.Status == "" {
p.Status = model.StatusActive
p.Status = service.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper()
if a.Platform == "" {
a.Platform = model.PlatformAnthropic
a.Platform = service.PlatformAnthropic
}
if a.Type == "" {
a.Type = model.AccountTypeOAuth
a.Type = service.AccountTypeOAuth
}
if a.Status == "" {
a.Status = model.StatusActive
a.Status = service.StatusActive
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = model.JSONB{}
a.Credentials = datatypes.JSONMap{}
}
if a.Extra == nil {
a.Extra = model.JSONB{}
a.Extra = datatypes.JSONMap{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper()
if k.Status == "" {
k.Status = model.StatusActive
k.Status = service.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper()
if c.Status == "" {
c.Status = model.StatusUnused
c.Status = service.StatusUnused
}
if c.Type == "" {
c.Type = model.RedeemTypeBalance
c.Type = service.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper()
if s.Status == "" {
s.Status = model.SubscriptionStatusActive
s.Status = service.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{
require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group")
}

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db}
}
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
err := r.db.WithContext(ctx).Create(group).Error
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var m groupModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return &group, nil
group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
}
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
}
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []groupModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Group{})
db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters
if platform != "" {
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return groups, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outGroups, paginationResultFromTotal(total, params), nil
}
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error
}
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Table("user_subscriptions").
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
Pluck("user_id", &affectedUserIDs).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
if err := tx.Exec(
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
id, id,
).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err
}
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{
group := &service.Group{
Name: "test-create",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, group)
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
}
func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
IsExclusive: true,
})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name)
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)

View File

@@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}

View File

@@ -0,0 +1,16 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db}
}
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error
func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
}
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var m proxyModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
}
return &proxy, nil
return proxyModelToService(&m), nil
}
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error
func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
}
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
}
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []proxyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{})
db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters
if protocol != "" {
@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return proxies, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outProxies, paginationResultFromTotal(total, params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return outProxies, nil
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error
if err != nil {
@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}).
err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID).
Count(&count).Error
return count, err
@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
var results []result
err := r.db.WithContext(ctx).
Model(&model.Account{}).
Table("accounts").
Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL").
Group("proxy_id").
@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Order("created_at DESC").
Find(&proxies).Error
if err != nil {
@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
}
// Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies))
for i, proxy := range proxies {
result[i] = model.ProxyWithAccountCount{
Proxy: proxy,
AccountCount: counts[proxy.ID],
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxy := proxyModelToService(&proxies[i])
if proxy == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
})
}
return result, nil
}
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
proxy := &service.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
s.Require().Equal(service.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Status: model.StatusActive,
Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Status: model.StatusActive,
Status: service.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive",
Status: model.StatusDisabled,
Status: service.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")

View File

@@ -4,10 +4,8 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db}
}
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
}
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
}
return r.db.WithContext(ctx).Create(&models).Error
}
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &code, nil
return redeemCodeModelToService(&m), nil
}
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &redeemCode, nil
return redeemCodeModelToService(&m), nil
}
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
}
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
var codes []redeemCodeModel
var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
// Apply filters
if codeType != "" {
db = db.Where("type = ?", codeType)
}
@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return codes, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outCodes, paginationResultFromTotal(total, params), nil
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
}
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{
"status": model.StatusUsed,
"status": service.StatusUsed,
"used_by": userID,
"used_at": now,
})
@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return nil
}
// ListByUser returns all redeem codes used by a specific user
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
if limit <= 0 {
limit = 10
}
var codes []redeemCodeModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("used_by = ?", userID).
Order("used_at DESC").
Limit(limit).
Find(&codes).Error
if err != nil {
return nil, err
}
return codes, nil
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{
code := &service.RedeemCode{
Code: "TEST-CREATE",
Type: model.RedeemTypeBalance,
Type: service.RedeemTypeBalance,
Value: 100,
Status: model.StatusUnused,
Status: service.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
codes := []service.RedeemCode{
{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete")
@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status)
s.Require().Equal(service.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription,
Type: service.RedeemTypeSubscription,
GroupID: &group.ID,
})
@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code.Value = 50
err := s.repo.Update(s.ctx, code)
@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status)
s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP",
Type: model.RedeemTypeSubscription,
Status: model.StatusUsed,
Type: service.RedeemTypeSubscription,
Status: service.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
codes := []model.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
codes := []service.RedeemCode{
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")

View File

@@ -6,33 +6,27 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type settingRepository struct {
db *gorm.DB
}
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db}
}
// Get 根据Key获取设置值
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
var m settingModel
err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
}
return &setting, nil
return settingModelToService(&m), nil
}
// GetValue 获取设置值字符串
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
return setting.Value, nil
}
// Set 设置值(存在则更新,不存在则创建)
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error
}).Create(m).Error
}
// GetMultiple 批量获取设置
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting
var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil {
return nil, err
@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
return result, nil
}
// SetMultiple 批量设置值
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings {
setting := &model.Setting{
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil {
}).Create(m).Error; err != nil {
return err
}
}
@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
})
}
// GetAll 获取所有设置
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting
var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil {
return nil, err
@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
return result, nil
}
// Delete 删除设置
func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
TokenCount int64 `gorm:"column:token_count"`
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
m := usageLogModelFromService(log)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUsageLogModelToService(log, m)
}
return err
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
var log usageLogModel
err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
}
return &log, nil
return usageLogModelToService(&log), nil
}
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
// UserStats 用户使用统计
@@ -125,7 +109,7 @@ type UserStats struct {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
today := timezone.Today()
// 总用户数
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers)
r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
// 今日新增用户数
r.db.WithContext(ctx).Model(&model.User{}).
r.db.WithContext(ctx).Model(&userModel{}).
Where("created_at >= ?", today).
Count(&stats.TodayNewUsers)
// 今日活跃用户数 (今日有请求的用户)
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Distinct("user_id").
Where("created_at >= ?", today).
Count(&stats.ActiveUsers)
// 总 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
// 活跃 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("status = ?", model.StatusActive).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("status = ?", service.StatusActive).
Count(&stats.ActiveApiKeys)
// 总账户数
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts)
r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
// 正常账户数 (schedulable=true, status=active)
r.db.WithContext(ctx).Model(&model.Account{}).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Count(&stats.NormalAccounts)
// 异常账户数 (status=error)
r.db.WithContext(ctx).Model(&model.Account{}).
Where("status = ?", model.StatusError).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ?", service.StatusError).
Count(&stats.ErrorAccounts)
// 限流账户数
r.db.WithContext(ctx).Model(&model.Account{}).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
Count(&stats.RateLimitAccounts)
// 过载账户数
r.db.WithContext(ctx).Model(&model.Account{}).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
Count(&stats.OverloadAccounts)
@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil
}
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
}
// GetAccountTodayStats 获取账号今日统计
@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
Cost float64 `gorm:"column:cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
Cost float64 `gorm:"column:cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today()
// API Key 统计
r.db.WithContext(ctx).Model(&model.ApiKey{}).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ?", userID).
Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("user_id = ? AND status = ?", userID, model.StatusActive).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ? AND status = ?", userID, service.StatusActive).
Count(&stats.ActiveApiKeys)
// 累计 Token 统计
@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
dateFormat = "YYYY-MM-DD"
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
model,
COUNT(*) as requests,
@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{})
db := r.db.WithContext(ctx).Model(&usageLogModel{})
// Apply filters
if filters.UserID > 0 {
@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
// UsageStats represents usage statistics
@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"`
TotalCost float64 `gorm:"column:total_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("user_id IN ?", userIDs).
Group("user_id").
@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"`
TodayCost float64 `gorm:"column:today_cost"`
}
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("user_id IN ? AND created_at >= ?", userIDs, today).
Group("user_id").
@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"`
TotalCost float64 `gorm:"column:total_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("api_key_id IN ?", apiKeyIDs).
Group("api_key_id").
@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"`
TodayCost float64 `gorm:"column:today_cost"`
}
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
Group("api_key_id").
@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
dateFormat = "YYYY-MM-DD"
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
model,
COUNT(*) as requests,
@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost float64 `gorm:"column:actual_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests,
@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration)
@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Models: models,
}, nil
}
type usageLogModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
ApiKeyID int64 `gorm:"index;not null"`
AccountID int64 `gorm:"index;not null"`
RequestID string `gorm:"size:64"`
Model string `gorm:"size:100;index;not null"`
GroupID *int64 `gorm:"index"`
SubscriptionID *int64 `gorm:"index"`
InputTokens int `gorm:"default:0;not null"`
OutputTokens int `gorm:"default:0;not null"`
CacheCreationTokens int `gorm:"default:0;not null"`
CacheReadTokens int `gorm:"default:0;not null"`
CacheCreation5mTokens int `gorm:"default:0;not null"`
CacheCreation1hTokens int `gorm:"default:0;not null"`
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
BillingType int8 `gorm:"type:smallint;default:0;not null"`
Stream bool `gorm:"default:false;not null"`
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time `gorm:"index;not null"`
User *userModel `gorm:"foreignKey:UserID"`
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
}
func (usageLogModel) TableName() string { return "usage_logs" }
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
if m == nil {
return nil
}
return &service.UsageLog{
ID: m.ID,
UserID: m.UserID,
ApiKeyID: m.ApiKeyID,
AccountID: m.AccountID,
RequestID: m.RequestID,
Model: m.Model,
GroupID: m.GroupID,
SubscriptionID: m.SubscriptionID,
InputTokens: m.InputTokens,
OutputTokens: m.OutputTokens,
CacheCreationTokens: m.CacheCreationTokens,
CacheReadTokens: m.CacheReadTokens,
CacheCreation5mTokens: m.CacheCreation5mTokens,
CacheCreation1hTokens: m.CacheCreation1hTokens,
InputCost: m.InputCost,
OutputCost: m.OutputCost,
CacheCreationCost: m.CacheCreationCost,
CacheReadCost: m.CacheReadCost,
TotalCost: m.TotalCost,
ActualCost: m.ActualCost,
RateMultiplier: m.RateMultiplier,
BillingType: m.BillingType,
Stream: m.Stream,
DurationMs: m.DurationMs,
FirstTokenMs: m.FirstTokenMs,
CreatedAt: m.CreatedAt,
User: userModelToService(m.User),
ApiKey: apiKeyModelToService(m.ApiKey),
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
Subscription: userSubscriptionModelToService(m.Subscription),
}
}
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
out := make([]service.UsageLog, 0, len(models))
for i := range models {
if s := usageLogModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
if log == nil {
return nil
}
return &usageLogModel{
ID: log.ID,
UserID: log.UserID,
ApiKeyID: log.ApiKeyID,
AccountID: log.AccountID,
RequestID: log.RequestID,
Model: log.Model,
GroupID: log.GroupID,
SubscriptionID: log.SubscriptionID,
InputTokens: log.InputTokens,
OutputTokens: log.OutputTokens,
CacheCreationTokens: log.CacheCreationTokens,
CacheReadTokens: log.CacheReadTokens,
CacheCreation5mTokens: log.CacheCreation5mTokens,
CacheCreation1hTokens: log.CacheCreation1hTokens,
InputCost: log.InputCost,
OutputCost: log.OutputCost,
CacheCreationCost: log.CacheCreationCost,
CacheReadCost: log.CacheReadCost,
TotalCost: log.TotalCost,
ActualCost: log.ActualCost,
RateMultiplier: log.RateMultiplier,
BillingType: log.BillingType,
Stream: log.Stream,
DurationMs: log.DurationMs,
FirstTokenMs: log.FirstTokenMs,
CreatedAt: log.CreatedAt,
}
}
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
if log == nil || m == nil {
return
}
log.ID = m.ID
log.CreatedAt = m.CreatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
log := &model.UsageLog{
func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
log := &model.UsageLog{
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{
userToday := mustCreateUser(s.T(), s.db, &userModel{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &model.User{
userOld := mustCreateUser(s.T(), s.db, &userModel{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)

View File

@@ -2,12 +2,13 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"gorm.io/gorm"
)
@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
}
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Save(user).Error
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []userModel
var total int64
db := r.db.WithContext(ctx).Model(&model.User{})
db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters
if status != "" {
@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 {
userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users))
userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users {
userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i]
u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription
var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil {
return nil, nil, err
}
@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Associate subscriptions with users
for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
}
}
return outUsers, paginationResultFromTotal(total, params), nil
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outUsers := make([]service.User, 0, len(users))
for i := range users {
outUsers = append(outUsers, *userModelToService(&users[i]))
}
return users, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outUsers, paginationResultFromTotal(total, params), nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil {
@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
}
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC").
First(&user).Error
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &model.User{
Email: "create@test.com",
Username: "testuser",
Role: model.RoleUser,
Status: model.StatusActive,
user := &service.User{
Email: "create@test.com",
Username: "testuser",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, user)
@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
}
func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status)
s.Require().Equal(service.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role)
s.Require().Equal(service.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
})
@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &model.User{
target := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{
userA := mustCreateUser(s.T(), s.db, &userModel{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
})
@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
})
@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{
admin1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "admin1@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "admin2@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "user@example.com",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "disabled@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
Email: "active@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
// --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{
user1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &model.User{
user2 := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &model.User{
_ = mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -4,111 +4,113 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type userSubscriptionRepository struct {
db *gorm.DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
}
// Create 创建订阅
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
err := r.db.WithContext(ctx).Create(sub).Error
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
// GetByID 根据ID获取订阅
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("User").
Preload("Group").
Preload("AssignedByUser").
First(&sub, id).Error
First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error
userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// Update 更新订阅
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return err
}
// Delete 删除订阅
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
}
// ListByUserID 获取用户的所有订阅
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&subs).Error
return subs, err
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// ListActiveByUserID 获取用户的所有有效订阅
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()).
userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC").
Find(&subs).Error
return subs, err
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
}
// List 获取所有订阅(分页,支持筛选)
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
if userID != nil {
query = query.Where("user_id = ?", *userID)
}
@@ -170,22 +160,87 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": start,
"weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// IncrementUsage 增加使用量
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
@@ -195,131 +250,150 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}).Error
}
// ResetDailyUsage 重置日使用量
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetWeeklyUsage 重置周使用量
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetMonthlyUsage 重置月使用量
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": activateTime,
"weekly_window_start": activateTime,
"monthly_window_start": activateTime,
"updated_at": time.Now(),
}).Error
}
// UpdateStatus 更新订阅状态
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
// ExtendExpiry 延长订阅过期时间
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
// UpdateNotes 更新订阅备注
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
"status": model.SubscriptionStatusExpired,
"status": service.SubscriptionStatusExpired,
"updated_at": time.Now(),
})
return result.RowsAffected, result.Error
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// CountByGroupID 获取分组的订阅数量
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID).
Count(&count).Error
return count, err
}
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()).
groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error
return count, err
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error
}
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
StartsAt: m.StartsAt,
ExpiresAt: m.ExpiresAt,
Status: m.Status,
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
sub := &model.UserSubscription{
sub := &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
})
@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
}))
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
})
@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
})
@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
// --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}