merge upstream main
@@ -39,9 +39,15 @@ import (
 // Design notes:
 // - client: Ent client, for type-safe ORM operations
 // - sql: raw SQL executor, for complex queries and bulk operations
+// - schedulerCache: scheduler cache, used to sync snapshots when account status changes
 type accountRepository struct {
 	client *dbent.Client // Ent ORM client
 	sql    sqlExecutor   // raw SQL execution interface
+	// schedulerCache is used to proactively sync the account snapshot to the cache when
+	// status changes, ensuring sticky sessions can promptly detect unavailable accounts.
 	schedulerCache service.SchedulerCache
 }

 type tempUnschedSnapshot struct {
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {

 // NewAccountRepository creates an account repository instance.
 // It is the exported constructor and returns the interface type to ease dependency injection.
-func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
-	return newAccountRepositoryWithSQL(client, sqlDB)
+func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
+	return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
 }

 // newAccountRepositoryWithSQL is the internal constructor that allows injecting the SQL executor,
 // which makes it easy to pass in mocks for unit tests.
-func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
-	return &accountRepository{client: client, sql: sqlq}
+func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
+	return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
 }

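Illustrative note, not part of the commit: with the widened constructor, callers now pass the scheduler cache explicitly, and a nil cache is tolerated because syncSchedulerAccountSnapshot (further below) returns early when the cache is missing. A minimal wiring sketch with hypothetical variable names:

	// Sketch: wiring the repository; schedulerCache may be nil for callers
	// (such as the updated tests) that do not need snapshot syncing.
	accountRepo := NewAccountRepository(entClient, sqlDB, schedulerCache)
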
 func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
 	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
 		log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
 	}
+	if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
+		r.syncSchedulerAccountSnapshot(ctx, account.ID)
+	}
 	return nil
 }

@@ -473,37 +482,6 @@ func (r *accountRepository) ListByPlatform(ctx context.Context, platform string)
 	return r.accountsToService(ctx, accounts)
 }

-func (r *accountRepository) ListByPlatformAndCredentialEmails(
-	ctx context.Context,
-	platform string,
-	emails []string,
-) ([]service.Account, error) {
-	if len(emails) == 0 {
-		return []service.Account{}, nil
-	}
-	args := make([]any, 0, len(emails))
-	for _, email := range emails {
-		if email == "" {
-			continue
-		}
-		args = append(args, email)
-	}
-	if len(args) == 0 {
-		return []service.Account{}, nil
-	}
-
-	accounts, err := r.client.Account.Query().
-		Where(dbaccount.PlatformEQ(platform)).
-		Where(func(s *entsql.Selector) {
-			s.Where(sqljson.ValueIn(dbaccount.FieldCredentials, args, sqljson.Path("email")))
-		}).
-		All(ctx)
-	if err != nil {
-		return nil, err
-	}
-	return r.accountsToService(ctx, accounts)
-}
-
 func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
 	now := time.Now()
 	_, err := r.client.Account.Update().
@@ -571,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
 	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
 		log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
 	}
+	r.syncSchedulerAccountSnapshot(ctx, id)
 	return nil
 }

+// syncSchedulerAccountSnapshot proactively syncs the account snapshot to the scheduler cache
+// when account status changes. It is called when an account is set to error, disabled,
+// unschedulable, or temporarily unschedulable, ensuring the scheduler and sticky-session
+// logic can promptly detect the latest account state and avoid using unavailable accounts.
+func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
+	if r == nil || r.schedulerCache == nil || accountID <= 0 {
+		return
+	}
+	account, err := r.GetByID(ctx, accountID)
+	if err != nil {
+		log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
+		return
+	}
+	if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
+		log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
+	}
+}
+
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -787,7 +788,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetRateLimitedAt(now).
|
||||
SetRateLimitResetAt(resetAt).
|
||||
SetLastUsedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -922,6 +922,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1018,7 +1019,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
 		builder.SetSessionWindowEnd(*end)
 	}
 	_, err := builder.Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	// Trigger a scheduler cache update (only when the window times actually changed).
+	if start != nil || end != nil {
+		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+			log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
+		}
+	}
+	return nil
 }

func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
@@ -1032,6 +1042,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1186,6 +1199,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
shouldSync = true
|
||||
}
|
||||
if updates.Schedulable != nil && !*updates.Schedulable {
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
 func (s *AccountRepoSuite) SetupTest() {
 	s.ctx = context.Background()
 	tx := testEntTx(s.T())
 	s.client = tx.Client()
-	s.repo = newAccountRepositoryWithSQL(s.client, tx)
+	s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
 }
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
 			// Re-acquire isolated resources for each case.
 			tx := testEntTx(s.T())
 			client := tx.Client()
-			repo := newAccountRepositoryWithSQL(client, tx)
+			repo := newAccountRepositoryWithSQL(client, tx, nil)
 			ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
disabled := service.StatusDisabled
|
||||
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||
Status: &disabled,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), rows)
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||
ids := map[int64]struct{}{}
|
||||
for _, acc := range cacheRecorder.setAccounts {
|
||||
ids[acc.ID] = struct{}{}
|
||||
}
|
||||
s.Require().Contains(ids, account1.ID)
|
||||
s.Require().Contains(ids, account2.ID)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
backend/internal/repository/aes_encryptor.go (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||
type AESEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewAESEncryptor creates a new AES encryptor
|
||||
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||
}
|
||||
|
||||
return &AESEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM
|
||||
// Output format: base64(nonce + ciphertext + tag)
|
||||
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Seal appends the ciphertext and tag to the nonce
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode as base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
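An illustrative round-trip sketch, not part of the commit, assuming config.Config exposes Totp.EncryptionKey as a plain nested field (as read by NewAESEncryptor above), that crypto/rand, encoding/hex, and log are imported, and that service.SecretEncryptor exposes the Encrypt/Decrypt methods implemented here:

	// Generate a random 32-byte key and store it as 64 hex characters,
	// matching the length check in NewAESEncryptor.
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		log.Fatal(err)
	}
	cfg := &config.Config{}
	cfg.Totp.EncryptionKey = hex.EncodeToString(raw)

	enc, err := NewAESEncryptor(cfg)
	if err != nil {
		log.Fatal(err)
	}
	ciphertext, _ := enc.Encrypt("JBSWY3DPEHPK3PXP") // base64(nonce || ciphertext+tag)
	plaintext, _ := enc.Decrypt(ciphertext)           // "JBSWY3DPEHPK3PXP"
	_ = plaintext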

backend/internal/repository/announcement_read_repo.go (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementReadRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
||||
return &announcementReadRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(announcementIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.UserIDEQ(userID),
|
||||
announcementread.AnnouncementIDIn(announcementIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.AnnouncementIDEQ(announcementID),
|
||||
announcementread.UserIDIn(userIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
||||
count, err := r.client.AnnouncementRead.Query().
|
||||
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
backend/internal/repository/announcement_repo.go (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
|
||||
return &announcementRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.Create().
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyAnnouncementEntityToService(a, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
|
||||
m, err := r.client.Announcement.Query().
|
||||
Where(announcement.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
return announcementEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.UpdateOneID(a.ID).
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
} else {
|
||||
builder.ClearStartsAt()
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
} else {
|
||||
builder.ClearEndsAt()
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
} else {
|
||||
builder.ClearCreatedBy()
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
} else {
|
||||
builder.ClearUpdatedBy()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
|
||||
a.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementRepository) List(
|
||||
ctx context.Context,
|
||||
params pagination.PaginationParams,
|
||||
filters service.AnnouncementListFilters,
|
||||
) ([]service.Announcement, *pagination.PaginationResult, error) {
|
||||
q := r.client.Announcement.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(announcement.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
announcement.Or(
|
||||
announcement.TitleContainsFold(filters.Search),
|
||||
announcement.ContentContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
out := announcementEntitiesToService(items)
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
announcement.StatusEQ(service.AnnouncementStatusActive),
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return announcementEntitiesToService(items), nil
|
||||
}
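For reference only (not part of the commit), the time-window predicate used by ListActive can be read as the following plain-Go check; an announcement is returned when its status is active, it has started (or has no start time), and it has not yet ended (EndsAt is exclusive, per EndsAtGT(now)):

// isActiveAt mirrors the ListActive query predicate (sketch only).
func isActiveAt(a service.Announcement, now time.Time) bool {
	if a.Status != service.AnnouncementStatusActive {
		return false
	}
	if a.StartsAt != nil && a.StartsAt.After(now) {
		return false // not started yet
	}
	if a.EndsAt != nil && !a.EndsAt.After(now) {
		return false // already ended
	}
	return true
}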
|
||||
|
||||
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
|
||||
out := make([]service.Announcement, 0, len(models))
|
||||
for i := range models {
|
||||
if s := announcementEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -12,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
 const (
-	apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
-	apiKeyRateLimitDuration  = 24 * time.Hour
-	apiKeyAuthCachePrefix    = "apikey:auth:"
+	apiKeyRateLimitKeyPrefix   = "apikey:ratelimit:"
+	apiKeyRateLimitDuration    = 24 * time.Hour
+	apiKeyAuthCachePrefix      = "apikey:auth:"
+	authCacheInvalidateChannel = "auth:cache:invalidate"
 )
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
||||
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
||||
}
|
||||
|
||||
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
||||
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
||||
|
||||
// Verify subscription is working
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := pubsub.Close(); err != nil {
|
||||
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msg != nil {
|
||||
handler(msg.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
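An illustrative wiring sketch, not part of the commit: each instance subscribes at startup and clears its own cached entry when another instance publishes the key it changed. The handler body and variable names (cache, localAuthCache, cacheKey) are hypothetical:

	// Startup: react to invalidations published by other instances.
	if err := cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
		localAuthCache.Delete(cacheKey) // hypothetical per-instance in-memory cache
	}); err != nil {
		log.Printf("subscribe auth cache invalidation: %v", err)
	}

	// After updating or revoking an API key, notify all instances.
	_ = cache.PublishAuthCacheInvalidation(ctx, cacheKey)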
|
||||
|
||||
@@ -389,17 +389,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
 	return &service.User{
-		ID:           u.ID,
-		Email:        u.Email,
-		Username:     u.Username,
-		Notes:        u.Notes,
-		PasswordHash: u.PasswordHash,
-		Role:         u.Role,
-		Balance:      u.Balance,
-		Concurrency:  u.Concurrency,
-		Status:       u.Status,
-		CreatedAt:    u.CreatedAt,
-		UpdatedAt:    u.UpdatedAt,
+		ID:                  u.ID,
+		Email:               u.Email,
+		Username:            u.Username,
+		Notes:               u.Notes,
+		PasswordHash:        u.PasswordHash,
+		Role:                u.Role,
+		Balance:             u.Balance,
+		Concurrency:         u.Concurrency,
+		Status:              u.Status,
+		TotpSecretEncrypted: u.TotpSecretEncrypted,
+		TotpEnabled:         u.TotpEnabled,
+		TotpEnabledAt:       u.TotpEnabledAt,
+		CreatedAt:           u.CreatedAt,
+		UpdatedAt:           u.UpdatedAt,
 	}
 }
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
 	client := s.clientFactory(proxyURL)

 	var orgs []struct {
-		UUID string `json:"uuid"`
+		UUID      string  `json:"uuid"`
+		Name      string  `json:"name"`
+		RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
 	}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
|
||||
@@ -14,37 +14,82 @@ import (
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
// 默认 User-Agent,与用户抓包的请求一致
|
||||
const defaultUsageUserAgent = "claude-code/2.1.7"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
allowPrivateHosts bool
|
||||
httpUpstream service.HTTPUpstream
|
||||
}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
|
||||
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
|
||||
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{
|
||||
usageURL: defaultClaudeUsageURL,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
|
||||
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("options is nil")
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
// 设置请求头(与抓包一致,但不设置 Accept-Encoding,让 Go 自动处理压缩)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
// 设置 User-Agent(优先使用缓存的 Fingerprint,否则使用默认值)
|
||||
userAgent := defaultUsageUserAgent
|
||||
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
|
||||
userAgent = opts.Fingerprint.UserAgent
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
// 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS
|
||||
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
|
||||
// accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置
|
||||
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 不启用 TLS 指纹,使用普通 HTTP 客户端
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: opts.ProxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
|
||||
@@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
if !endLocal.After(startLocal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startLocal.Truncate(time.Hour)
|
||||
hourEnd := endLocal.Truncate(time.Hour)
|
||||
if endLocal.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDay(startLocal)
|
||||
dayEnd := truncateToDay(endLocal)
|
||||
if endLocal.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
|
||||
@@ -9,13 +9,27 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
const (
|
||||
verifyCodeKeyPrefix = "verify_code:"
|
||||
passwordResetKeyPrefix = "password_reset:"
|
||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset token methods
|
||||
|
||||
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||
key := passwordResetKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.PasswordResetTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||
key := passwordResetKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
key := passwordResetKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset email cooldown methods
|
||||
|
||||
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
key := passwordResetSentAtKey(email)
|
||||
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
key := passwordResetSentAtKey(email)
|
||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||
}
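An illustrative caller-side sketch, not part of the commit, combining the cooldown helpers; the sender function, sentinel error, and 60-second TTL are assumptions:

	// Skip sending if a reset email went out recently; otherwise record the send.
	if cache.IsPasswordResetEmailInCooldown(ctx, email) {
		return errResetEmailCooldown // hypothetical sentinel error
	}
	if err := sendPasswordResetEmail(ctx, email, token); err != nil { // hypothetical sender
		return err
	}
	return cache.SetPasswordResetEmailCooldown(ctx, email, 60*time.Second)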
|
||||
|
||||
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer seedCancel()
|
||||
if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
|
||||
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID removes the sticky-session binding for the given session.
// It is called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
	key := buildSessionKey(groupID, sessionHash)
	return c.rdb.Del(ctx, key).Err()
}
|
||||
|
||||
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
|
||||
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
 	s.ctx = context.Background()
 	tx := testEntTx(s.T())
 	s.client = tx.Client()
-	s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
+	s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
 }
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
// 如果未启用 TLS 指纹,直接使用标准请求路径
|
||||
if !enableTLSFingerprint {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TLS 指纹已启用,记录调试日志
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
targetHost = req.URL.Host
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 TLS 指纹 Profile
|
||||
registry := tlsfingerprint.GlobalRegistry()
|
||||
profile := registry.GetProfileByAccountID(accountID)
|
||||
if profile == nil {
|
||||
// 如果获取不到 profile,回退到普通请求
|
||||
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
|
||||
|
||||
// 获取或创建带 TLS 指纹的客户端
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
slog.Debug("tls_fingerprint_evicting_stale_client",
|
||||
"account_id", accountID,
|
||||
"cache_key", cacheKey,
|
||||
"proxy_changed", entry.proxyKey != proxyKey,
|
||||
"pool_changed", entry.poolKey != poolKey)
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
|
||||
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
// - profile: TLS 指纹配置
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 配置错误
|
||||
//
|
||||
// 代理类型处理:
|
||||
// - nil/空: 直连,使用 TLSFingerprintDialer
|
||||
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
|
||||
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
|
||||
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
|
||||
// 根据代理类型选择合适的 TLS 指纹 Dialer
|
||||
if proxyURL == nil {
|
||||
// 直连:使用 TLSFingerprintDialer
|
||||
slog.Debug("tls_fingerprint_transport_direct")
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
transport.DialTLSContext = dialer.DialTLSContext
|
||||
} else {
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
// SOCKS5 代理:使用 SOCKS5ProxyDialer
|
||||
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = socks5Dialer.DialTLSContext
|
||||
case "http", "https":
|
||||
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
|
||||
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = httpDialer.DialTLSContext
|
||||
default:
|
||||
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
|
||||
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
|
||||
@@ -11,8 +11,10 @@ import (
|
||||
)
|
||||
|
||||
 const (
-	fingerprintKeyPrefix = "fingerprint:"
-	fingerprintTTL       = 24 * time.Hour
+	fingerprintKeyPrefix   = "fingerprint:"
+	fingerprintTTL         = 24 * time.Hour
+	maskedSessionKeyPrefix = "masked_session:"
+	maskedSessionTTL       = 15 * time.Minute
 )
|
||||
|
||||
// fingerprintKey generates the Redis key for account fingerprint cache.
|
||||
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// maskedSessionKey generates the Redis key for masked session ID cache.
|
||||
func maskedSessionKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
|
||||
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
|
||||
key := maskedSessionKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
|
||||
key := maskedSessionKey(accountID)
|
||||
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package repository

import (
    "context"
    "fmt"
    "net/http"
    "net/url"
    "time"

    infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
    "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
    "github.com/Wei-Shaw/sub2api/internal/service"
    "github.com/imroc/req/v3"
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie

    resp, err := client.R().
        SetContext(ctx).
        SetHeader("User-Agent", "codex-cli/0.91.0").
        SetFormDataFromValues(formData).
        SetSuccessResult(&tokenResp).
        Post(s.tokenURL)

    if err != nil {
        return nil, fmt.Errorf("request failed: %w", err)
        return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
    }

    if !resp.IsSuccessState() {
        return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
        return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
    }

    return &tokenResp, nil
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro

    resp, err := client.R().
        SetContext(ctx).
        SetHeader("User-Agent", "codex-cli/0.91.0").
        SetFormDataFromValues(formData).
        SetSuccessResult(&tokenResp).
        Post(s.tokenURL)

    if err != nil {
        return nil, fmt.Errorf("request failed: %w", err)
        return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
    }

    if !resp.IsSuccessState() {
        return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
        return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
    }

    return &tokenResp, nil
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
func createOpenAIReqClient(proxyURL string) *req.Client {
    return getSharedReqClient(reqClientOptions{
        ProxyURL: proxyURL,
        Timeout:  60 * time.Second,
        Timeout:  120 * time.Second,
    })
}
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
    require.ErrorContains(s.T(), err, "status 401")
}

func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
    client := NewOpenAIOAuthClient()
    svc, ok := client.(*openaiOAuthService)
    require.True(t, ok)
    require.Equal(t, openai.TokenURL, svc.tokenURL)
}

func TestOpenAIOAuthServiceSuite(t *testing.T) {
    suite.Run(t, new(OpenAIOAuthServiceSuite))
}
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
    }

    // View filter: errors vs excluded vs all.
    // Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
    // Excluded = business-limited errors (quota/concurrency/billing).
    // Upstream 429/529 are included in errors view to match SLA calculation.
    view := ""
    if filter != nil {
        view = strings.ToLower(strings.TrimSpace(filter.View))
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
    switch view {
    case "", "errors":
        clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
        clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
    case "excluded":
        clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
        clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
    case "all":
        // no-op
    default:
        // treat unknown as default 'errors'
        clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
        clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
    }
    if len(filter.StatusCodes) > 0 {
        args = append(args, pq.Array(filter.StatusCodes))
@@ -1,6 +1,7 @@
package repository

import (
    "crypto/tls"
    "time"

    "github.com/Wei-Shaw/sub2api/internal/config"
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
// buildRedisOptions builds the Redis connection options.
// Pool and timeout parameters are read from the config file so they can be tuned for production.
func buildRedisOptions(cfg *config.Config) *redis.Options {
    return &redis.Options{
    opts := &redis.Options{
        Addr:     cfg.Redis.Address(),
        Password: cfg.Redis.Password,
        DB:       cfg.Redis.DB,
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
        PoolSize:     cfg.Redis.PoolSize,     // connection pool size
        MinIdleConns: cfg.Redis.MinIdleConns, // minimum idle connections
    }

    if cfg.Redis.EnableTLS {
        opts.TLSConfig = &tls.Config{
            MinVersion: tls.VersionTLS12,
            ServerName: cfg.Redis.Host,
        }
    }

    return opts
}
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
    require.Equal(t, 4*time.Second, opts.WriteTimeout)
    require.Equal(t, 100, opts.PoolSize)
    require.Equal(t, 10, opts.MinIdleConns)
    require.Nil(t, opts.TLSConfig)

    // Test case with TLS enabled
    cfgTLS := &config.Config{
        Redis: config.RedisConfig{
            Host:      "localhost",
            EnableTLS: true,
        },
    }
    optsTLS := buildRedisOptions(cfgTLS)
    require.NotNil(t, optsTLS.TLSConfig)
    require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
}
@@ -14,6 +14,7 @@ type reqClientOptions struct {
    ProxyURL    string        // proxy URL (supports http/https/socks5)
    Timeout     time.Duration // request timeout
    Impersonate bool          // whether to impersonate a Chrome browser fingerprint
    ForceHTTP2  bool          // whether to force HTTP/2
}

// sharedReqClients caches req client instances keyed by their configuration options.
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
    }

    client := req.C().SetTimeout(opts.Timeout)
    if opts.ForceHTTP2 {
        client = client.EnableForceHTTP2()
    }
    if opts.Impersonate {
        client = client.ImpersonateChrome()
    }
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}

func buildReqClientKey(opts reqClientOptions) string {
    return fmt.Sprintf("%s|%s|%t",
    return fmt.Sprintf("%s|%s|%t|%t",
        strings.TrimSpace(opts.ProxyURL),
        opts.Timeout.String(),
        opts.Impersonate,
        opts.ForceHTTP2,
    )
}
90 backend/internal/repository/req_client_pool_test.go Normal file
@@ -0,0 +1,90 @@
package repository

import (
    "reflect"
    "sync"
    "testing"
    "time"
    "unsafe"

    "github.com/imroc/req/v3"
    "github.com/stretchr/testify/require"
)

func forceHTTPVersion(t *testing.T, client *req.Client) string {
    t.Helper()
    transport := client.GetTransport()
    field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
    require.True(t, field.IsValid(), "forceHttpVersion field not found")
    require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
    return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}

func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
    sharedReqClients = sync.Map{}
    base := reqClientOptions{
        ProxyURL: "http://proxy.local:8080",
        Timeout:  time.Second,
    }
    clientDefault := getSharedReqClient(base)

    force := base
    force.ForceHTTP2 = true
    clientForce := getSharedReqClient(force)

    require.NotSame(t, clientDefault, clientForce)
    require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}

func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
    sharedReqClients = sync.Map{}
    opts := reqClientOptions{
        ProxyURL: "http://proxy.local:8080",
        Timeout:  2 * time.Second,
    }
    first := getSharedReqClient(opts)
    second := getSharedReqClient(opts)
    require.Same(t, first, second)
}

func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
    sharedReqClients = sync.Map{}
    opts := reqClientOptions{
        ProxyURL: " http://proxy.local:8080 ",
        Timeout:  3 * time.Second,
    }
    key := buildReqClientKey(opts)
    sharedReqClients.Store(key, "invalid")

    client := getSharedReqClient(opts)

    require.NotNil(t, client)
    loaded, ok := sharedReqClients.Load(key)
    require.True(t, ok)
    require.IsType(t, "invalid", loaded)
}

func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
    sharedReqClients = sync.Map{}
    opts := reqClientOptions{
        ProxyURL:    " http://proxy.local:8080 ",
        Timeout:     4 * time.Second,
        Impersonate: true,
    }
    client := getSharedReqClient(opts)

    require.NotNil(t, client)
    require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}

func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
    sharedReqClients = sync.Map{}
    client := createOpenAIReqClient("http://proxy.local:8080")
    require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}

func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
    sharedReqClients = sync.Map{}
    client := createGeminiReqClient("http://proxy.local:8080")
    require.Equal(t, "", forceHTTPVersion(t, client))
}
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
        return nil, false, err
    }
    if len(ids) == 0 {
        return []*service.Account{}, true, nil
        // Treat an empty snapshot as a cache miss so the caller falls back to a database query.
        // This resolves the race where an account is bound right after a new group is created.
        return nil, false, nil
    }

    keys := make([]string, 0, len(ids))
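A minimal caller-side sketch (not part of the diff) of how the new miss signal is consumed; the loadFromDB helper is hypothetical:

    accounts, ok, err := cache.GetSnapshot(ctx, bucket)
    if err != nil {
        return nil, err
    }
    if !ok {
        // Empty or missing snapshot: fall back to the database (hypothetical helper).
        return loadFromDB(ctx, bucket)
    }
    return accounts, nil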
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {

    _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")

    accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
    accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
    outboxRepo := NewSchedulerOutboxRepository(integrationDB)
    cache := NewSchedulerCache(rdb)

@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}

// GetActiveSessionCountBatch returns the active session count for multiple accounts in one round trip.
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
    if len(accountIDs) == 0 {
        return make(map[int64]int), nil
    }
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco

    // Execute the lookups in a single batch via a pipeline.
    pipe := c.rdb.Pipeline()
    idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())

    cmds := make(map[int64]*redis.Cmd, len(accountIDs))
    for _, accountID := range accountIDs {
        key := sessionLimitKey(accountID)
        // Use each account's own idle timeout, falling back to the default when none is provided.
        idleTimeout := c.defaultIdleTimeout
        if idleTimeouts != nil {
            if t, ok := idleTimeouts[accountID]; ok && t > 0 {
                idleTimeout = t
            }
        }
        idleTimeoutSeconds := int(idleTimeout.Seconds())
        cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
    }

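Illustrative call site (not part of the diff), assuming the caller already knows each account's idle timeout:

    idleTimeouts := map[int64]time.Duration{
        101: 5 * time.Minute, // account-specific override
        102: 0,               // zero or missing entries fall back to the default idle timeout
    }
    counts, err := cache.GetActiveSessionCountBatch(ctx, []int64{101, 102, 103}, idleTimeouts)
    if err != nil {
        return err
    }
    _ = counts[103] // accounts without an override also use the default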
82 backend/internal/repository/simple_mode_default_groups.go Normal file
@@ -0,0 +1,82 @@
package repository

import (
    "context"
    "fmt"

    dbent "github.com/Wei-Shaw/sub2api/ent"
    "github.com/Wei-Shaw/sub2api/ent/group"
    "github.com/Wei-Shaw/sub2api/internal/service"
)

func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
    if client == nil {
        return fmt.Errorf("nil ent client")
    }

    requiredByPlatform := map[string]int{
        service.PlatformAnthropic:   1,
        service.PlatformOpenAI:      1,
        service.PlatformGemini:      1,
        service.PlatformAntigravity: 2,
    }

    for platform, minCount := range requiredByPlatform {
        count, err := client.Group.Query().
            Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
            Count(ctx)
        if err != nil {
            return fmt.Errorf("count groups for platform %s: %w", platform, err)
        }

        if platform == service.PlatformAntigravity {
            if count < minCount {
                for i := count; i < minCount; i++ {
                    name := fmt.Sprintf("%s-default-%d", platform, i+1)
                    if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
                        return err
                    }
                }
            }
            continue
        }

        // Non-antigravity platforms: ensure <platform>-default exists.
        name := platform + "-default"
        if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
            return err
        }
    }

    return nil
}

func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
    exists, err := client.Group.Query().
        Where(group.NameEQ(name), group.DeletedAtIsNil()).
        Exist(ctx)
    if err != nil {
        return fmt.Errorf("check group exists %s: %w", name, err)
    }
    if exists {
        return nil
    }

    _, err = client.Group.Create().
        SetName(name).
        SetDescription("Auto-created default group").
        SetPlatform(platform).
        SetStatus(service.StatusActive).
        SetSubscriptionType(service.SubscriptionTypeStandard).
        SetRateMultiplier(1.0).
        SetIsExclusive(false).
        Save(ctx)
    if err != nil {
        if dbent.IsConstraintError(err) {
            // Concurrent server startups may race on creation; treat as success.
            return nil
        }
        return fmt.Errorf("create default group %s: %w", name, err)
    }
    return nil
}
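A sketch of the intended startup usage; the call site is not shown in this diff, so treat the wiring below as an assumption:

    // During simple-mode bootstrap, seed the default groups once; constraint
    // errors from concurrent replicas are already swallowed inside the helper.
    if err := ensureSimpleModeDefaultGroups(ctx, entClient); err != nil {
        log.Fatalf("seed default groups: %v", err)
    }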
@@ -0,0 +1,84 @@
//go:build integration

package repository

import (
    "context"
    "testing"
    "time"

    "github.com/Wei-Shaw/sub2api/ent/group"
    "github.com/Wei-Shaw/sub2api/internal/service"
    "github.com/stretchr/testify/require"
)

func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
    ctx := context.Background()
    tx := testEntTx(t)
    client := tx.Client()

    seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
    defer cancel()

    require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))

    assertGroupExists := func(name string) {
        exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
        require.NoError(t, err)
        require.True(t, exists, "expected group %s to exist", name)
    }

    assertGroupExists(service.PlatformAnthropic + "-default")
    assertGroupExists(service.PlatformOpenAI + "-default")
    assertGroupExists(service.PlatformGemini + "-default")
    assertGroupExists(service.PlatformAntigravity + "-default-1")
    assertGroupExists(service.PlatformAntigravity + "-default-2")
}

func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
    ctx := context.Background()
    tx := testEntTx(t)
    client := tx.Client()

    seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
    defer cancel()

    // Create and then soft-delete an anthropic default group.
    g, err := client.Group.Create().
        SetName(service.PlatformAnthropic + "-default").
        SetPlatform(service.PlatformAnthropic).
        SetStatus(service.StatusActive).
        SetSubscriptionType(service.SubscriptionTypeStandard).
        SetRateMultiplier(1.0).
        SetIsExclusive(false).
        Save(seedCtx)
    require.NoError(t, err)

    _, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
    require.NoError(t, err)

    require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))

    // New active one should exist.
    count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
    require.NoError(t, err)
    require.Equal(t, 1, count)
}

func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
    ctx := context.Background()
    tx := testEntTx(t)
    client := tx.Client()

    seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
    defer cancel()

    mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
    mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})

    require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))

    count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
    require.NoError(t, err)
    require.GreaterOrEqual(t, count, 2)
}
149 backend/internal/repository/totp_cache.go Normal file
@@ -0,0 +1,149 @@
package repository

import (
    "context"
    "encoding/json"
    "fmt"
    "time"

    "github.com/redis/go-redis/v9"

    "github.com/Wei-Shaw/sub2api/internal/service"
)

const (
    totpSetupKeyPrefix    = "totp:setup:"
    totpLoginKeyPrefix    = "totp:login:"
    totpAttemptsKeyPrefix = "totp:attempts:"
    totpAttemptsTTL       = 15 * time.Minute
)

// TotpCache implements service.TotpCache using Redis
type TotpCache struct {
    rdb *redis.Client
}

// NewTotpCache creates a new TOTP cache
func NewTotpCache(rdb *redis.Client) service.TotpCache {
    return &TotpCache{rdb: rdb}
}

// GetSetupSession retrieves a TOTP setup session
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
    key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
    data, err := c.rdb.Get(ctx, key).Bytes()
    if err != nil {
        if err == redis.Nil {
            return nil, nil
        }
        return nil, fmt.Errorf("get setup session: %w", err)
    }

    var session service.TotpSetupSession
    if err := json.Unmarshal(data, &session); err != nil {
        return nil, fmt.Errorf("unmarshal setup session: %w", err)
    }

    return &session, nil
}

// SetSetupSession stores a TOTP setup session
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
    key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
    data, err := json.Marshal(session)
    if err != nil {
        return fmt.Errorf("marshal setup session: %w", err)
    }

    if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
        return fmt.Errorf("set setup session: %w", err)
    }

    return nil
}

// DeleteSetupSession deletes a TOTP setup session
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
    key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
    return c.rdb.Del(ctx, key).Err()
}

// GetLoginSession retrieves a TOTP login session
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
    key := totpLoginKeyPrefix + tempToken
    data, err := c.rdb.Get(ctx, key).Bytes()
    if err != nil {
        if err == redis.Nil {
            return nil, nil
        }
        return nil, fmt.Errorf("get login session: %w", err)
    }

    var session service.TotpLoginSession
    if err := json.Unmarshal(data, &session); err != nil {
        return nil, fmt.Errorf("unmarshal login session: %w", err)
    }

    return &session, nil
}

// SetLoginSession stores a TOTP login session
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
    key := totpLoginKeyPrefix + tempToken
    data, err := json.Marshal(session)
    if err != nil {
        return fmt.Errorf("marshal login session: %w", err)
    }

    if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
        return fmt.Errorf("set login session: %w", err)
    }

    return nil
}

// DeleteLoginSession deletes a TOTP login session
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
    key := totpLoginKeyPrefix + tempToken
    return c.rdb.Del(ctx, key).Err()
}

// IncrementVerifyAttempts increments the verify attempt counter
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
    key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)

    // Use pipeline for atomic increment and set TTL
    pipe := c.rdb.Pipeline()
    incrCmd := pipe.Incr(ctx, key)
    pipe.Expire(ctx, key, totpAttemptsTTL)

    if _, err := pipe.Exec(ctx); err != nil {
        return 0, fmt.Errorf("increment verify attempts: %w", err)
    }

    count, err := incrCmd.Result()
    if err != nil {
        return 0, fmt.Errorf("get increment result: %w", err)
    }

    return int(count), nil
}

// GetVerifyAttempts gets the current verify attempt count
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
    key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
    count, err := c.rdb.Get(ctx, key).Int()
    if err != nil {
        if err == redis.Nil {
            return 0, nil
        }
        return 0, fmt.Errorf("get verify attempts: %w", err)
    }
    return count, nil
}

// ClearVerifyAttempts clears the verify attempt counter
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
    key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
    return c.rdb.Del(ctx, key).Err()
}
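A hypothetical verification flow (not part of the diff) showing how the attempt counter is meant to be used; the maxTotpAttempts constant is an assumption, not defined in this change:

    const maxTotpAttempts = 5 // assumed limit

    attempts, err := totpCache.IncrementVerifyAttempts(ctx, userID)
    if err != nil {
        return err
    }
    if attempts > maxTotpAttempts {
        return fmt.Errorf("too many TOTP attempts, retry after the 15 minute window expires")
    }
    // ... verify the code, then reset the counter on success
    if err := totpCache.ClearVerifyAttempts(ctx, userID); err != nil {
        return err
    }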
551 backend/internal/repository/usage_cleanup_repo.go Normal file
@@ -0,0 +1,551 @@
package repository

import (
    "context"
    "database/sql"
    "encoding/json"
    "errors"
    "fmt"
    "strings"
    "time"

    dbent "github.com/Wei-Shaw/sub2api/ent"
    dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
    "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
    "github.com/Wei-Shaw/sub2api/internal/service"
)

type usageCleanupRepository struct {
    client *dbent.Client
    sql    sqlExecutor
}

func NewUsageCleanupRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageCleanupRepository {
    return newUsageCleanupRepositoryWithSQL(client, sqlDB)
}

func newUsageCleanupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageCleanupRepository {
    return &usageCleanupRepository{client: client, sql: sqlq}
}

func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
    if task == nil {
        return nil
    }
    if r.client != nil {
        return r.createTaskWithEnt(ctx, task)
    }
    return r.createTaskWithSQL(ctx, task)
}

func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
    if r.client != nil {
        return r.listTasksWithEnt(ctx, params)
    }
    var total int64
    if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
        return nil, nil, err
    }
    if total == 0 {
        return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
    }

    query := `
        SELECT id, status, filters, created_by, deleted_rows, error_message,
               canceled_by, canceled_at,
               started_at, finished_at, created_at, updated_at
        FROM usage_cleanup_tasks
        ORDER BY created_at DESC, id DESC
        LIMIT $1 OFFSET $2
    `
    rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
    if err != nil {
        return nil, nil, err
    }
    defer func() { _ = rows.Close() }()

    tasks := make([]service.UsageCleanupTask, 0)
    for rows.Next() {
        var task service.UsageCleanupTask
        var filtersJSON []byte
        var errMsg sql.NullString
        var canceledBy sql.NullInt64
        var canceledAt sql.NullTime
        var startedAt sql.NullTime
        var finishedAt sql.NullTime
        if err := rows.Scan(
            &task.ID,
            &task.Status,
            &filtersJSON,
            &task.CreatedBy,
            &task.DeletedRows,
            &errMsg,
            &canceledBy,
            &canceledAt,
            &startedAt,
            &finishedAt,
            &task.CreatedAt,
            &task.UpdatedAt,
        ); err != nil {
            return nil, nil, err
        }
        if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
            return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
        }
        if errMsg.Valid {
            task.ErrorMsg = &errMsg.String
        }
        if canceledBy.Valid {
            v := canceledBy.Int64
            task.CanceledBy = &v
        }
        if canceledAt.Valid {
            task.CanceledAt = &canceledAt.Time
        }
        if startedAt.Valid {
            task.StartedAt = &startedAt.Time
        }
        if finishedAt.Valid {
            task.FinishedAt = &finishedAt.Time
        }
        tasks = append(tasks, task)
    }
    if err := rows.Err(); err != nil {
        return nil, nil, err
    }
    return tasks, paginationResultFromTotal(total, params), nil
}

func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
    if staleRunningAfterSeconds <= 0 {
        staleRunningAfterSeconds = 1800
    }
    query := `
        WITH next AS (
            SELECT id
            FROM usage_cleanup_tasks
            WHERE status = $1
               OR (
                    status = $2
                    AND started_at IS NOT NULL
                    AND started_at < NOW() - ($3 * interval '1 second')
               )
            ORDER BY created_at ASC
            LIMIT 1
            FOR UPDATE SKIP LOCKED
        )
        UPDATE usage_cleanup_tasks AS tasks
        SET status = $4,
            started_at = NOW(),
            finished_at = NULL,
            error_message = NULL,
            updated_at = NOW()
        FROM next
        WHERE tasks.id = next.id
        RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
                  tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
    `
    var task service.UsageCleanupTask
    var filtersJSON []byte
    var errMsg sql.NullString
    var startedAt sql.NullTime
    var finishedAt sql.NullTime
    if err := scanSingleRow(
        ctx,
        r.sql,
        query,
        []any{
            service.UsageCleanupStatusPending,
            service.UsageCleanupStatusRunning,
            staleRunningAfterSeconds,
            service.UsageCleanupStatusRunning,
        },
        &task.ID,
        &task.Status,
        &filtersJSON,
        &task.CreatedBy,
        &task.DeletedRows,
        &errMsg,
        &startedAt,
        &finishedAt,
        &task.CreatedAt,
        &task.UpdatedAt,
    ); err != nil {
        if errors.Is(err, sql.ErrNoRows) {
            return nil, nil
        }
        return nil, err
    }
    if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
        return nil, fmt.Errorf("parse cleanup filters: %w", err)
    }
    if errMsg.Valid {
        task.ErrorMsg = &errMsg.String
    }
    if startedAt.Valid {
        task.StartedAt = &startedAt.Time
    }
    if finishedAt.Valid {
        task.FinishedAt = &finishedAt.Time
    }
    return &task, nil
}

func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
    if r.client != nil {
        return r.getTaskStatusWithEnt(ctx, taskID)
    }
    var status string
    if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
        return "", err
    }
    return status, nil
}

func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
    if r.client != nil {
        return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows)
    }
    query := `
        UPDATE usage_cleanup_tasks
        SET deleted_rows = $1,
            updated_at = NOW()
        WHERE id = $2
    `
    _, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
    return err
}

func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
    if r.client != nil {
        return r.cancelTaskWithEnt(ctx, taskID, canceledBy)
    }
    query := `
        UPDATE usage_cleanup_tasks
        SET status = $1,
            canceled_by = $3,
            canceled_at = NOW(),
            finished_at = NOW(),
            error_message = NULL,
            updated_at = NOW()
        WHERE id = $2
          AND status IN ($4, $5)
        RETURNING id
    `
    var id int64
    err := scanSingleRow(ctx, r.sql, query, []any{
        service.UsageCleanupStatusCanceled,
        taskID,
        canceledBy,
        service.UsageCleanupStatusPending,
        service.UsageCleanupStatusRunning,
    }, &id)
    if errors.Is(err, sql.ErrNoRows) {
        return false, nil
    }
    if err != nil {
        return false, err
    }
    return true, nil
}

func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
    if r.client != nil {
        return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows)
    }
    query := `
        UPDATE usage_cleanup_tasks
        SET status = $1,
            deleted_rows = $2,
            finished_at = NOW(),
            updated_at = NOW()
        WHERE id = $3
    `
    _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
    return err
}

func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
    if r.client != nil {
        return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg)
    }
    query := `
        UPDATE usage_cleanup_tasks
        SET status = $1,
            deleted_rows = $2,
            error_message = $3,
            finished_at = NOW(),
            updated_at = NOW()
        WHERE id = $4
    `
    _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
    return err
}

func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
    if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
        return 0, fmt.Errorf("cleanup filters missing time range")
    }
    whereClause, args := buildUsageCleanupWhere(filters)
    if whereClause == "" {
        return 0, fmt.Errorf("cleanup filters missing time range")
    }
    args = append(args, limit)
    query := fmt.Sprintf(`
        WITH target AS (
            SELECT id
            FROM usage_logs
            WHERE %s
            ORDER BY created_at ASC, id ASC
            LIMIT $%d
        )
        DELETE FROM usage_logs
        WHERE id IN (SELECT id FROM target)
        RETURNING id
    `, whereClause, len(args))

    rows, err := r.sql.QueryContext(ctx, query, args...)
    if err != nil {
        return 0, err
    }
    defer func() { _ = rows.Close() }()

    var deleted int64
    for rows.Next() {
        deleted++
    }
    if err := rows.Err(); err != nil {
        return 0, err
    }
    return deleted, nil
}

func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
    conditions := make([]string, 0, 8)
    args := make([]any, 0, 8)
    idx := 1
    if !filters.StartTime.IsZero() {
        conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
        args = append(args, filters.StartTime)
        idx++
    }
    if !filters.EndTime.IsZero() {
        conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
        args = append(args, filters.EndTime)
        idx++
    }
    if filters.UserID != nil {
        conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
        args = append(args, *filters.UserID)
        idx++
    }
    if filters.APIKeyID != nil {
        conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
        args = append(args, *filters.APIKeyID)
        idx++
    }
    if filters.AccountID != nil {
        conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
        args = append(args, *filters.AccountID)
        idx++
    }
    if filters.GroupID != nil {
        conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
        args = append(args, *filters.GroupID)
        idx++
    }
    if filters.Model != nil {
        model := strings.TrimSpace(*filters.Model)
        if model != "" {
            conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
            args = append(args, model)
            idx++
        }
    }
    if filters.Stream != nil {
        conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
        args = append(args, *filters.Stream)
        idx++
    }
    if filters.BillingType != nil {
        conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
        args = append(args, *filters.BillingType)
    }
    return strings.Join(conditions, " AND "), args
}

func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error {
    client := clientFromContext(ctx, r.client)
    filtersJSON, err := json.Marshal(task.Filters)
    if err != nil {
        return fmt.Errorf("marshal cleanup filters: %w", err)
    }
    created, err := client.UsageCleanupTask.
        Create().
        SetStatus(task.Status).
        SetFilters(json.RawMessage(filtersJSON)).
        SetCreatedBy(task.CreatedBy).
        SetDeletedRows(task.DeletedRows).
        Save(ctx)
    if err != nil {
        return err
    }
    task.ID = created.ID
    task.CreatedAt = created.CreatedAt
    task.UpdatedAt = created.UpdatedAt
    return nil
}

func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error {
    filtersJSON, err := json.Marshal(task.Filters)
    if err != nil {
        return fmt.Errorf("marshal cleanup filters: %w", err)
    }
    query := `
        INSERT INTO usage_cleanup_tasks (
            status,
            filters,
            created_by,
            deleted_rows
        ) VALUES ($1, $2, $3, $4)
        RETURNING id, created_at, updated_at
    `
    if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
        return err
    }
    return nil
}

func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
    client := clientFromContext(ctx, r.client)
    query := client.UsageCleanupTask.Query()
    total, err := query.Clone().Count(ctx)
    if err != nil {
        return nil, nil, err
    }
    if total == 0 {
        return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
    }
    rows, err := query.
        Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)).
        Offset(params.Offset()).
        Limit(params.Limit()).
        All(ctx)
    if err != nil {
        return nil, nil, err
    }
    tasks := make([]service.UsageCleanupTask, 0, len(rows))
    for _, row := range rows {
        task, err := usageCleanupTaskFromEnt(row)
        if err != nil {
            return nil, nil, err
        }
        tasks = append(tasks, task)
    }
    return tasks, paginationResultFromTotal(int64(total), params), nil
}

func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) {
    client := clientFromContext(ctx, r.client)
    task, err := client.UsageCleanupTask.Query().
        Where(dbusagecleanuptask.IDEQ(taskID)).
        Only(ctx)
    if err != nil {
        if dbent.IsNotFound(err) {
            return "", sql.ErrNoRows
        }
        return "", err
    }
    return task.Status, nil
}

func (r *usageCleanupRepository) updateTaskProgressWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
    client := clientFromContext(ctx, r.client)
    now := time.Now()
    _, err := client.UsageCleanupTask.Update().
        Where(dbusagecleanuptask.IDEQ(taskID)).
        SetDeletedRows(deletedRows).
        SetUpdatedAt(now).
        Save(ctx)
    return err
}

func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
    client := clientFromContext(ctx, r.client)
    now := time.Now()
    affected, err := client.UsageCleanupTask.Update().
        Where(
            dbusagecleanuptask.IDEQ(taskID),
            dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning),
        ).
        SetStatus(service.UsageCleanupStatusCanceled).
        SetCanceledBy(canceledBy).
        SetCanceledAt(now).
        SetFinishedAt(now).
        ClearErrorMessage().
        SetUpdatedAt(now).
        Save(ctx)
    if err != nil {
        return false, err
    }
    return affected > 0, nil
}

func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
    client := clientFromContext(ctx, r.client)
    now := time.Now()
    _, err := client.UsageCleanupTask.Update().
        Where(dbusagecleanuptask.IDEQ(taskID)).
        SetStatus(service.UsageCleanupStatusSucceeded).
        SetDeletedRows(deletedRows).
        SetFinishedAt(now).
        SetUpdatedAt(now).
        Save(ctx)
    return err
}

func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
    client := clientFromContext(ctx, r.client)
    now := time.Now()
    _, err := client.UsageCleanupTask.Update().
        Where(dbusagecleanuptask.IDEQ(taskID)).
        SetStatus(service.UsageCleanupStatusFailed).
        SetDeletedRows(deletedRows).
        SetErrorMessage(errorMsg).
        SetFinishedAt(now).
        SetUpdatedAt(now).
        Save(ctx)
    return err
}

func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) {
    task := service.UsageCleanupTask{
        ID:          row.ID,
        Status:      row.Status,
        CreatedBy:   row.CreatedBy,
        DeletedRows: row.DeletedRows,
        CreatedAt:   row.CreatedAt,
        UpdatedAt:   row.UpdatedAt,
    }
    if len(row.Filters) > 0 {
        if err := json.Unmarshal(row.Filters, &task.Filters); err != nil {
            return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err)
        }
    }
    if row.ErrorMessage != nil {
        task.ErrorMsg = row.ErrorMessage
    }
    if row.CanceledBy != nil {
        task.CanceledBy = row.CanceledBy
    }
    if row.CanceledAt != nil {
        task.CanceledAt = row.CanceledAt
    }
    if row.StartedAt != nil {
        task.StartedAt = row.StartedAt
    }
    if row.FinishedAt != nil {
        task.FinishedAt = row.FinishedAt
    }
    return task, nil
}
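For reference (not part of the diff), a filter carrying a time range plus a user ID produces a clause like the one sketched below, based on buildUsageCleanupWhere above:

    userID := int64(42)
    where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
        StartTime: start,
        EndTime:   end,
        UserID:    &userID,
    })
    // where == "created_at >= $1 AND created_at <= $2 AND user_id = $3"
    // args  == []any{start, end, int64(42)}
    // DeleteUsageLogsBatch then appends the batch limit as placeholder $4.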
251 backend/internal/repository/usage_cleanup_repo_ent_test.go Normal file
@@ -0,0 +1,251 @@
package repository

import (
    "context"
    "database/sql"
    "encoding/json"
    "testing"
    "time"

    dbent "github.com/Wei-Shaw/sub2api/ent"
    "github.com/Wei-Shaw/sub2api/ent/enttest"
    dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
    "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
    "github.com/Wei-Shaw/sub2api/internal/service"
    "github.com/stretchr/testify/require"

    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
    _ "modernc.org/sqlite"
)

func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) {
    t.Helper()
    db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared")
    require.NoError(t, err)
    t.Cleanup(func() { _ = db.Close() })
    _, err = db.Exec("PRAGMA foreign_keys = ON")
    require.NoError(t, err)

    drv := entsql.OpenDB(dialect.SQLite, db)
    client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
    t.Cleanup(func() { _ = client.Close() })

    repo := &usageCleanupRepository{client: client, sql: db}
    return repo, client
}

func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) {
    repo, _ := newUsageCleanupEntRepo(t)

    start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
    end := start.Add(24 * time.Hour)
    task := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusPending,
        Filters:   service.UsageCleanupFilters{StartTime: start, EndTime: end},
        CreatedBy: 9,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task))
    require.NotZero(t, task.ID)

    task2 := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusRunning,
        Filters:   service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)},
        CreatedBy: 10,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task2))

    tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
    require.NoError(t, err)
    require.Len(t, tasks, 2)
    require.Equal(t, int64(2), result.Total)
    require.Greater(t, tasks[0].ID, tasks[1].ID)
    require.Equal(t, start, tasks[1].Filters.StartTime)
    require.Equal(t, end, tasks[1].Filters.EndTime)
}

func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) {
    repo, _ := newUsageCleanupEntRepo(t)

    tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
    require.NoError(t, err)
    require.Empty(t, tasks)
    require.Equal(t, int64(0), result.Total)
}

func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) {
    repo, client := newUsageCleanupEntRepo(t)

    task := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusPending,
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 3,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task))

    status, err := repo.GetTaskStatus(context.Background(), task.ID)
    require.NoError(t, err)
    require.Equal(t, service.UsageCleanupStatusPending, status)

    _, err = repo.GetTaskStatus(context.Background(), task.ID+99)
    require.ErrorIs(t, err, sql.ErrNoRows)

    require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42))
    loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
    require.NoError(t, err)
    require.Equal(t, int64(42), loaded.DeletedRows)
}

func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) {
    repo, client := newUsageCleanupEntRepo(t)

    task := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusPending,
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 5,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task))

    ok, err := repo.CancelTask(context.Background(), task.ID, 7)
    require.NoError(t, err)
    require.True(t, ok)

    loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
    require.NoError(t, err)
    require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status)
    require.NotNil(t, loaded.CanceledBy)
    require.NotNil(t, loaded.CanceledAt)
    require.NotNil(t, loaded.FinishedAt)

    loaded.Status = service.UsageCleanupStatusSucceeded
    _, err = client.UsageCleanupTask.Update().Where(dbusagecleanuptask.IDEQ(task.ID)).SetStatus(loaded.Status).Save(context.Background())
    require.NoError(t, err)

    ok, err = repo.CancelTask(context.Background(), task.ID, 7)
    require.NoError(t, err)
    require.False(t, ok)
}

func TestUsageCleanupRepositoryEntCancelError(t *testing.T) {
    repo, client := newUsageCleanupEntRepo(t)

    task := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusPending,
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 5,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task))

    require.NoError(t, client.Close())
    _, err := repo.CancelTask(context.Background(), task.ID, 7)
    require.Error(t, err)
}

func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) {
    repo, client := newUsageCleanupEntRepo(t)

    task := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusRunning,
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 12,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task))

    require.NoError(t, repo.MarkTaskSucceeded(context.Background(), task.ID, 6))
    loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
    require.NoError(t, err)
    require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status)
    require.Equal(t, int64(6), loaded.DeletedRows)
    require.NotNil(t, loaded.FinishedAt)

    task2 := &service.UsageCleanupTask{
        Status:    service.UsageCleanupStatusRunning,
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 12,
    }
    require.NoError(t, repo.CreateTask(context.Background(), task2))

    require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom"))
    loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID)
    require.NoError(t, err)
    require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status)
    require.Equal(t, "boom", *loaded2.ErrorMessage)
}

func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) {
    repo, _ := newUsageCleanupEntRepo(t)

    task := &service.UsageCleanupTask{
        Status:    "invalid",
        Filters:   service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
        CreatedBy: 1,
    }
    require.Error(t, repo.CreateTask(context.Background(), task))
}

func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) {
    repo, client := newUsageCleanupEntRepo(t)

    now := time.Now().UTC()
    driver, ok := client.Driver().(*entsql.Driver)
    require.True(t, ok)
    _, err := driver.DB().ExecContext(
        context.Background(),
        `INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
         VALUES (?, ?, ?, ?, ?, ?)`,
        service.UsageCleanupStatusPending,
        []byte("invalid-json"),
        int64(1),
        int64(0),
        now,
        now,
    )
    require.NoError(t, err)

    _, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
    require.Error(t, err)
}

func TestUsageCleanupTaskFromEntFull(t *testing.T) {
    start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
    end := start.Add(24 * time.Hour)
    errMsg := "failed"
    canceledBy := int64(2)
    canceledAt := start.Add(time.Minute)
    startedAt := start.Add(2 * time.Minute)
    finishedAt := start.Add(3 * time.Minute)
    filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
    filtersJSON, err := json.Marshal(filters)
    require.NoError(t, err)

    task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
        ID:           10,
        Status:       service.UsageCleanupStatusFailed,
        Filters:      filtersJSON,
        CreatedBy:    11,
        DeletedRows:  7,
        ErrorMessage: &errMsg,
        CanceledBy:   &canceledBy,
        CanceledAt:   &canceledAt,
        StartedAt:    &startedAt,
        FinishedAt:   &finishedAt,
        CreatedAt:    start,
        UpdatedAt:    end,
    })
    require.NoError(t, err)
    require.Equal(t, int64(10), task.ID)
    require.Equal(t, service.UsageCleanupStatusFailed, task.Status)
    require.NotNil(t, task.ErrorMsg)
    require.NotNil(t, task.CanceledBy)
    require.NotNil(t, task.CanceledAt)
    require.NotNil(t, task.StartedAt)
    require.NotNil(t, task.FinishedAt)
}

func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) {
    task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
        Filters: json.RawMessage("invalid-json"),
    })
    require.Error(t, err)
    require.Empty(t, task)
}
482
backend/internal/repository/usage_cleanup_repo_test.go
Normal file
482
backend/internal/repository/usage_cleanup_repo_test.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
func TestNewUsageCleanupRepository(t *testing.T) {
|
||||
db, _ := newSQLMock(t)
|
||||
repo := NewUsageCleanupRepository(nil, db)
|
||||
require.NotNil(t, repo)
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
|
||||
CreatedBy: 12,
|
||||
}
|
||||
now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
|
||||
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
|
||||
|
||||
err := repo.CreateTask(context.Background(), task)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), task.ID)
|
||||
require.Equal(t, now, task.CreatedAt)
|
||||
require.Equal(t, now, task.UpdatedAt)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
err := repo.CreateTask(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageCleanupRepository{sql: db}
|
||||
|
||||
task := &service.UsageCleanupTask{
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
|
||||
CreatedBy: 1,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
|
||||
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
err := repo.CreateTask(context.Background(), task)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))

	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.NoError(t, err)
	require.Empty(t, tasks)
	require.Equal(t, int64(0), result.Total)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryListTasks(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(2 * time.Hour)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
	filtersJSON, err := json.Marshal(filters)
	require.NoError(t, err)

	createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
	updatedAt := createdAt.Add(time.Minute)
	rows := sqlmock.NewRows([]string{
		"id", "status", "filters", "created_by", "deleted_rows", "error_message",
		"canceled_by", "canceled_at",
		"started_at", "finished_at", "created_at", "updated_at",
	}).AddRow(
		int64(1),
		service.UsageCleanupStatusSucceeded,
		filtersJSON,
		int64(2),
		int64(9),
		"error",
		nil,
		nil,
		start,
		end,
		createdAt,
		updatedAt,
	)

	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
	mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
		WithArgs(20, 0).
		WillReturnRows(rows)

	tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, int64(1), tasks[0].ID)
	require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
	require.Equal(t, int64(2), tasks[0].CreatedBy)
	require.Equal(t, int64(9), tasks[0].DeletedRows)
	require.NotNil(t, tasks[0].ErrorMsg)
	require.Equal(t, "error", *tasks[0].ErrorMsg)
	require.NotNil(t, tasks[0].StartedAt)
	require.NotNil(t, tasks[0].FinishedAt)
	require.Equal(t, int64(1), result.Total)
	require.NoError(t, mock.ExpectationsWereMet())
}

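// The rows above include nullable columns (error_message, canceled_by, canceled_at,
// started_at, finished_at), so ListTasks presumably scans them through sql.Null*
// wrappers before exposing the pointer fields asserted on above. Illustrative
// helpers for that conversion (assumptions, not necessarily the ones the
// repository uses):
func nullStringPtrSketch(v sql.NullString) *string {
	if !v.Valid {
		return nil
	}
	s := v.String
	return &s
}

func nullTimePtrSketch(v sql.NullTime) *time.Time {
	if !v.Valid {
		return nil
	}
	t := v.Time
	return &t
}
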
func TestUsageCleanupRepositoryListTasksQueryError(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(2)))
	mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
		WithArgs(20, 0).
		WillReturnError(sql.ErrConnDone)

	_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	rows := sqlmock.NewRows([]string{
		"id", "status", "filters", "created_by", "deleted_rows", "error_message",
		"canceled_by", "canceled_at",
		"started_at", "finished_at", "created_at", "updated_at",
	}).AddRow(
		int64(1),
		service.UsageCleanupStatusSucceeded,
		[]byte("not-json"),
		int64(2),
		int64(9),
		nil,
		nil,
		nil,
		nil,
		nil,
		time.Now().UTC(),
		time.Now().UTC(),
	)

	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
	mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
		WithArgs(20, 0).
		WillReturnRows(rows)

	_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "status", "filters", "created_by", "deleted_rows", "error_message",
			"started_at", "finished_at", "created_at", "updated_at",
		}))

	task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.NoError(t, err)
	require.Nil(t, task)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
	filtersJSON, err := json.Marshal(filters)
	require.NoError(t, err)

	rows := sqlmock.NewRows([]string{
		"id", "status", "filters", "created_by", "deleted_rows", "error_message",
		"started_at", "finished_at", "created_at", "updated_at",
	}).AddRow(
		int64(4),
		service.UsageCleanupStatusRunning,
		filtersJSON,
		int64(7),
		int64(0),
		nil,
		start,
		nil,
		start,
		start,
	)

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnRows(rows)

	task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.NoError(t, err)
	require.NotNil(t, task)
	require.Equal(t, int64(4), task.ID)
	require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
	require.Equal(t, int64(7), task.CreatedBy)
	require.NotNil(t, task.StartedAt)
	require.Nil(t, task.ErrorMsg)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnError(sql.ErrConnDone)

	_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	rows := sqlmock.NewRows([]string{
		"id", "status", "filters", "created_by", "deleted_rows", "error_message",
		"started_at", "finished_at", "created_at", "updated_at",
	}).AddRow(
		int64(4),
		service.UsageCleanupStatusRunning,
		[]byte("invalid"),
		int64(7),
		int64(0),
		nil,
		nil,
		nil,
		time.Now().UTC(),
		time.Now().UTC(),
	)

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
		WillReturnRows(rows)

	_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

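// claimNextPendingTaskSketchSQL illustrates a statement shape consistent with the
// expectations above: four arguments in the order (pending, running, stale seconds,
// running), and either a full task row or no row at all. It is an inference, not
// the repository's actual SQL -- the classic PostgreSQL idiom is an UPDATE over a
// FOR UPDATE SKIP LOCKED sub-select, so concurrent workers never claim the same
// task, while running tasks whose started_at is older than the timeout get
// reclaimed.
const claimNextPendingTaskSketchSQL = `
UPDATE usage_cleanup_tasks
SET status = $2, started_at = NOW(), updated_at = NOW()
WHERE id = (
	SELECT id FROM usage_cleanup_tasks
	WHERE status = $1
	   OR (status = $4 AND started_at < NOW() - make_interval(secs => $3))
	ORDER BY id
	LIMIT 1
	FOR UPDATE SKIP LOCKED
)
RETURNING id, status, filters, created_by, deleted_rows, error_message,
	started_at, finished_at, created_at, updated_at`
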
func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectExec("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
		WillReturnResult(sqlmock.NewResult(0, 1))

	err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectExec("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
		WillReturnResult(sqlmock.NewResult(0, 1))

	err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
		WithArgs(int64(9)).
		WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))

	status, err := repo.GetTaskStatus(context.Background(), 9)
	require.NoError(t, err)
	require.Equal(t, service.UsageCleanupStatusPending, status)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryGetTaskStatusQueryError(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
		WithArgs(int64(9)).
		WillReturnError(sql.ErrConnDone)

	_, err := repo.GetTaskStatus(context.Background(), 9)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectExec("UPDATE usage_cleanup_tasks").
		WithArgs(int64(123), int64(8)).
		WillReturnResult(sqlmock.NewResult(0, 1))

	err := repo.UpdateTaskProgress(context.Background(), 8, 123)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))

	ok, err := repo.CancelTask(context.Background(), 6, 9)
	require.NoError(t, err)
	require.True(t, ok)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryCancelTaskNoRows(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	mock.ExpectQuery("UPDATE usage_cleanup_tasks").
		WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
		WillReturnRows(sqlmock.NewRows([]string{"id"}))

	ok, err := repo.CancelTask(context.Background(), 6, 9)
	require.NoError(t, err)
	require.False(t, ok)
	require.NoError(t, mock.ExpectationsWereMet())
}

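// cancelTaskSketchSQL shows why the two cancel tests use ExpectQuery rather than
// ExpectExec: the repository appears to detect "nothing to cancel" from the absence
// of a RETURNING row instead of from RowsAffected. The placeholder order mirrors
// WithArgs above (status, task id, canceling user, then the two cancellable
// states); column names and the exact predicate are assumptions, not the actual SQL.
const cancelTaskSketchSQL = `
UPDATE usage_cleanup_tasks
SET status = $1, canceled_by = $3, canceled_at = NOW(), updated_at = NOW()
WHERE id = $2 AND status IN ($4, $5)
RETURNING id`
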
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
	db, _ := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	_, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
	require.Error(t, err)
}

func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	userID := int64(3)
	model := " gpt-4 "
	filters := service.UsageCleanupFilters{
		StartTime: start,
		EndTime:   end,
		UserID:    &userID,
		Model:     &model,
	}

	mock.ExpectQuery("DELETE FROM usage_logs").
		WithArgs(start, end, userID, "gpt-4", 2).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))

	deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
	require.NoError(t, err)
	require.Equal(t, int64(2), deleted)
	require.NoError(t, mock.ExpectationsWereMet())
}

func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageCleanupRepository{sql: db}

	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}

	mock.ExpectQuery("DELETE FROM usage_logs").
		WithArgs(start, end, 5).
		WillReturnError(sql.ErrConnDone)

	_, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
	require.Error(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}

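// deleteUsageLogsBatchSketch shows the statement shape the three tests above appear
// to exercise: a missing time range is rejected before any SQL runs, the
// buildUsageCleanupWhere fragment is embedded in a sub-select, the batch size is
// bound as the final positional argument, and the deleted count is the number of
// RETURNING rows. This is an inference from the mock expectations (and assumes fmt
// is imported), not the repository's actual implementation.
func deleteUsageLogsBatchSketch(filters service.UsageCleanupFilters, batchSize int) (string, []any) {
	where, args := buildUsageCleanupWhere(filters)
	query := fmt.Sprintf(
		"DELETE FROM usage_logs WHERE id IN (SELECT id FROM usage_logs WHERE %s ORDER BY id LIMIT $%d) RETURNING id",
		where, len(args)+1,
	)
	return query, append(args, batchSize)
}
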
func TestBuildUsageCleanupWhere(t *testing.T) {
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	userID := int64(1)
	apiKeyID := int64(2)
	accountID := int64(3)
	groupID := int64(4)
	model := " gpt-4 "
	stream := true
	billingType := int8(2)

	where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
		StartTime:   start,
		EndTime:     end,
		UserID:      &userID,
		APIKeyID:    &apiKeyID,
		AccountID:   &accountID,
		GroupID:     &groupID,
		Model:       &model,
		Stream:      &stream,
		BillingType: &billingType,
	})

	require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
	require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}

func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	model := " "

	where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
		StartTime: start,
		EndTime:   end,
		Model:     &model,
	})

	require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
	require.Equal(t, []any{start, end}, args)
}

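// The two tests above pin buildUsageCleanupWhere down fairly tightly: the time range
// is always present, optional pointer filters are appended in a fixed order with
// sequential placeholders, and Model is trimmed and skipped when blank. A sketch
// consistent with those expectations (assumes fmt and strings are imported; the
// real helper may be organized differently):
func buildUsageCleanupWhereSketch(f service.UsageCleanupFilters) (string, []any) {
	clauses := []string{"created_at >= $1", "created_at <= $2"}
	args := []any{f.StartTime, f.EndTime}
	add := func(column string, value any) {
		args = append(args, value)
		clauses = append(clauses, fmt.Sprintf("%s = $%d", column, len(args)))
	}
	if f.UserID != nil {
		add("user_id", *f.UserID)
	}
	if f.APIKeyID != nil {
		add("api_key_id", *f.APIKeyID)
	}
	if f.AccountID != nil {
		add("account_id", *f.AccountID)
	}
	if f.GroupID != nil {
		add("group_id", *f.GroupID)
	}
	if f.Model != nil {
		if m := strings.TrimSpace(*f.Model); m != "" {
			add("model", m)
		}
	}
	if f.Stream != nil {
		add("stream", *f.Stream)
	}
	if f.BillingType != nil {
		add("billing_type", *f.BillingType)
	}
	return strings.Join(clauses, " AND "), args
}
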
@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}

// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
	dateFormat := "YYYY-MM-DD"
	if granularity == "hour" {
		dateFormat = "YYYY-MM-DD HH24:00"
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
		args = append(args, *stream)
	}
	if billingType != nil {
		query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
		args = append(args, int16(*billingType))
	}
	query += " GROUP BY date ORDER BY date ASC"

	rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}

// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
	actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
	// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
	// When aggregating by account_id only, actual cost uses the account rate multiplier (total_cost * account_rate_multiplier).
	if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
		args = append(args, *stream)
	}
	if billingType != nil {
		query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
		args = append(args, int16(*billingType))
	}
	query += " GROUP BY model ORDER BY total_tokens DESC"

	rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
		}
	}

	models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
	models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
	if err != nil {
		models = []ModelStat{}
	}

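// Both query builders above append billing_type with the next free positional
// placeholder and widen *billingType to int16 before binding it. Call sites that do
// not filter by billing type keep passing a trailing nil, as the updated suite tests
// below show; a caller that does filter would pass a pointer. Hedged usage sketch
// (variable names are illustrative):
//
//	bt := int8(2)
//	trend, err := repo.GetUsageTrendWithFilters(ctx, start, end, "day", 0, 0, 0, 0, "", nil, &bt)
//	stats, err := repo.GetModelStatsWithFilters(ctx, start, end, 0, 0, 0, 0, nil, &bt)
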
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
	endTime := base.Add(48 * time.Hour)

	// Test with user filter
	trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
	trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
	s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
	s.Require().Len(trend, 2)

	// Test with apiKey filter
	trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
	trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
	s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
	s.Require().Len(trend, 2)

	// Test with both filters
	trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
	trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
	s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
	s.Require().Len(trend, 2)
}
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
	startTime := base.Add(-1 * time.Hour)
	endTime := base.Add(3 * time.Hour)

	trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
	trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
	s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
	s.Require().Len(trend, 2)
}
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
	endTime := base.Add(2 * time.Hour)

	// Test with user filter
	stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
	stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
	s.Require().NoError(err, "GetModelStatsWithFilters user filter")
	s.Require().Len(stats, 2)

	// Test with apiKey filter
	stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
	stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
	s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
	s.Require().Len(stats, 2)

	// Test with account filter
	stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
	stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
	s.Require().NoError(err, "GetModelStatsWithFilters account filter")
	s.Require().Len(stats, 2)
}

@@ -7,6 +7,7 @@ import (
	"fmt"
	"sort"
	"strings"
	"time"

	dbent "github.com/Wei-Shaw/sub2api/ent"
	dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -189,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
			dbuser.Or(
				dbuser.EmailContainsFold(filters.Search),
				dbuser.UsernameContainsFold(filters.Search),
				dbuser.NotesContainsFold(filters.Search),
			),
		)
	}
@@ -466,3 +468,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
	dst.CreatedAt = src.CreatedAt
	dst.UpdatedAt = src.UpdatedAt
}

// UpdateTotpSecret 更新用户的 TOTP 加密密钥
// UpdateTotpSecret updates the user's encrypted TOTP secret.
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	client := clientFromContext(ctx, r.client)
	update := client.User.UpdateOneID(userID)
	if encryptedSecret == nil {
		update = update.ClearTotpSecretEncrypted()
	} else {
		update = update.SetTotpSecretEncrypted(*encryptedSecret)
	}
	_, err := update.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}

// EnableTotp 启用用户的 TOTP 双因素认证
// EnableTotp enables TOTP two-factor authentication for the user.
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
	client := clientFromContext(ctx, r.client)
	_, err := client.User.UpdateOneID(userID).
		SetTotpEnabled(true).
		SetTotpEnabledAt(time.Now()).
		Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}

// DisableTotp 禁用用户的 TOTP 双因素认证
// DisableTotp disables TOTP two-factor authentication for the user.
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
	client := clientFromContext(ctx, r.client)
	_, err := client.User.UpdateOneID(userID).
		SetTotpEnabled(false).
		ClearTotpEnabledAt().
		ClearTotpSecretEncrypted().
		Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}

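// The three repository methods above cover the full TOTP lifecycle on the
// persistence side. A hypothetical service-level flow on top of them (ordering and
// variable names are assumptions, not part of this change):
//
//	// setup: store the encrypted secret while TOTP remains disabled
//	if err := userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil {
//		return err
//	}
//	// after the user verifies a code: flip the flag and stamp the enabled-at time
//	if err := userRepo.EnableTotp(ctx, userID); err != nil {
//		return err
//	}
//	// disabling clears the flag, the timestamp and the stored secret in one update
//	if err := userRepo.DisableTotp(ctx, userID); err != nil {
//		return err
//	}
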
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
	return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}

func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
	client := clientFromContext(ctx, r.client)
	q := client.UserSubscription.Query()
	if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
	if groupID != nil {
		q = q.Where(usersubscription.GroupIDEQ(*groupID))
	}
	if status != "" {

	// Status filtering with real-time expiration check
	now := time.Now()
	switch status {
	case service.SubscriptionStatusActive:
		// Active: status is active AND not yet expired
		q = q.Where(
			usersubscription.StatusEQ(service.SubscriptionStatusActive),
			usersubscription.ExpiresAtGT(now),
		)
	case service.SubscriptionStatusExpired:
		// Expired: status is expired OR (status is active but already expired)
		q = q.Where(
			usersubscription.Or(
				usersubscription.StatusEQ(service.SubscriptionStatusExpired),
				usersubscription.And(
					usersubscription.StatusEQ(service.SubscriptionStatusActive),
					usersubscription.ExpiresAtLTE(now),
				),
			),
		)
	case "":
		// No filter
	default:
		// Other status (e.g., revoked)
		q = q.Where(usersubscription.StatusEQ(status))
	}

@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
		return nil, nil, err
	}

	// Apply sorting
	q = q.WithUser().WithGroup().WithAssignedByUser()

	// Determine sort field
	var field string
	switch sortBy {
	case "expires_at":
		field = usersubscription.FieldExpiresAt
	case "status":
		field = usersubscription.FieldStatus
	default:
		field = usersubscription.FieldCreatedAt
	}

	// Determine sort order (default: desc)
	if sortOrder == "asc" && sortBy != "" {
		q = q.Order(dbent.Asc(field))
	} else {
		q = q.Order(dbent.Desc(field))
	}

	subs, err := q.
		WithUser().
		WithGroup().
		WithAssignedByUser().
		Order(dbent.Desc(usersubscription.FieldCreatedAt)).
		Offset(params.Offset()).
		Limit(params.Limit()).
		All(ctx)

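// Existing callers keep the previous behaviour (created_at descending) by passing
// empty sortBy/sortOrder strings, as the updated suite tests below do. A hedged
// sketch of a caller that uses the new sorting parameters (values are illustrative):
//
//	subs, page, err := repo.List(ctx,
//		pagination.PaginationParams{Page: 1, PageSize: 20},
//		nil, nil, service.SubscriptionStatusActive, "expires_at", "asc")
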
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
	group := s.mustCreateGroup("g-list")
	s.mustCreateSubscription(user.ID, group.ID, nil)

	subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
	subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
	s.Require().NoError(err, "List")
	s.Require().Len(subs, 1)
	s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
	s.mustCreateSubscription(user1.ID, group.ID, nil)
	s.mustCreateSubscription(user2.ID, group.ID, nil)

	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
	s.Require().NoError(err)
	s.Require().Len(subs, 1)
	s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
	s.mustCreateSubscription(user.ID, g1.ID, nil)
	s.mustCreateSubscription(user.ID, g2.ID, nil)

	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
	s.Require().NoError(err)
	s.Require().Len(subs, 1)
	s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
		c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
	})

	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
	subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
	s.Require().NoError(err)
	s.Require().Len(subs, 1)
	s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

@@ -56,7 +56,10 @@ var ProviderSet = wire.NewSet(
	NewProxyRepository,
	NewRedeemCodeRepository,
	NewPromoCodeRepository,
	NewAnnouncementRepository,
	NewAnnouncementReadRepository,
	NewUsageLogRepository,
	NewUsageCleanupRepository,
	NewDashboardAggregationRepository,
	NewSettingRepository,
	NewOpsRepository,
@@ -81,6 +84,10 @@ var ProviderSet = wire.NewSet(
	NewSchedulerCache,
	NewSchedulerOutboxRepository,
	NewProxyLatencyCache,
	NewTotpCache,

	// Encryptors
	NewAESEncryptor,

	// HTTP service ports (DI Strategy A: return interface directly)
	NewTurnstileVerifier,