Merge branch 'main' of https://github.com/james-6-23/sub2api
This commit is contained in:
@@ -15,7 +15,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -50,11 +50,6 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
type tempUnschedSnapshot struct {
|
||||
until *time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
account.CreatedAt = created.CreatedAt
|
||||
account.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
accountIDs = append(accountIDs, acc.ID)
|
||||
}
|
||||
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outByID[entAcc.ID] = out
|
||||
}
|
||||
|
||||
@@ -388,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
}
|
||||
account.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
@@ -429,7 +415,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -541,7 +527,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
},
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -576,7 +562,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
||||
}
|
||||
payload := map[string]any{"last_used": lastUsedPayload}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -591,7 +577,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -611,11 +597,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
|
||||
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(accountIDs))
|
||||
seen := make(map[int64]struct{}, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := r.GetByIDs(ctx, uniqueIDs)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -649,7 +672,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -666,7 +689,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -739,7 +762,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -824,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
now := time.Now()
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ(platform),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
now := time.Now()
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformIn(platforms...),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
@@ -847,7 +915,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -894,7 +962,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -908,7 +976,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -927,7 +995,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -946,7 +1014,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -962,7 +1030,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -986,7 +1054,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1010,7 +1078,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1032,7 +1100,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1047,7 +1115,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
@@ -1075,7 +1143,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
||||
}
|
||||
if rows > 0 {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1111,7 +1179,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1205,7 +1273,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if rows > 0 {
|
||||
payload := map[string]any{"account_ids": ids}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
@@ -1215,9 +1283,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshots(ctx, ids)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1309,10 +1375,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1338,10 +1400,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[acc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outAccounts = append(outAccounts, *out)
|
||||
}
|
||||
|
||||
@@ -1366,48 +1424,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
|
||||
)
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
|
||||
FROM accounts
|
||||
WHERE id = ANY($1)
|
||||
`, pq.Array(accountIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var until sql.NullTime
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(&id, &until, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var untilPtr *time.Time
|
||||
if until.Valid {
|
||||
tmp := until.Time
|
||||
untilPtr = &tmp
|
||||
}
|
||||
if reason.Valid {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
|
||||
} else {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
|
||||
proxyMap := make(map[int64]*service.Proxy)
|
||||
if len(proxyIDs) == 0 {
|
||||
@@ -1518,31 +1534,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
rateMultiplier := m.RateMultiplier
|
||||
|
||||
return &service.Account{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Notes: m.Notes,
|
||||
Platform: m.Platform,
|
||||
Type: m.Type,
|
||||
Credentials: copyJSONMap(m.Credentials),
|
||||
Extra: copyJSONMap(m.Extra),
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
RateLimitedAt: m.RateLimitedAt,
|
||||
RateLimitResetAt: m.RateLimitResetAt,
|
||||
OverloadUntil: m.OverloadUntil,
|
||||
SessionWindowStart: m.SessionWindowStart,
|
||||
SessionWindowEnd: m.SessionWindowEnd,
|
||||
SessionWindowStatus: derefString(m.SessionWindowStatus),
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Notes: m.Notes,
|
||||
Platform: m.Platform,
|
||||
Type: m.Type,
|
||||
Credentials: copyJSONMap(m.Credentials),
|
||||
Extra: copyJSONMap(m.Extra),
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
RateLimitedAt: m.RateLimitedAt,
|
||||
RateLimitResetAt: m.RateLimitResetAt,
|
||||
OverloadUntil: m.OverloadUntil,
|
||||
TempUnschedulableUntil: m.TempUnschedulableUntil,
|
||||
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
|
||||
SessionWindowStart: m.SessionWindowStart,
|
||||
SessionWindowEnd: m.SessionWindowEnd,
|
||||
SessionWindowStatus: derefString(m.SessionWindowStatus),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1578,3 +1596,64 @@ func joinClauses(clauses []string, sep string) string {
|
||||
func itoa(v int) string {
|
||||
return strconv.Itoa(v)
|
||||
}
|
||||
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||
//
|
||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
||||
//
|
||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||
//
|
||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||
dbaccount.DeletedAtIsNil(),
|
||||
func(s *entsql.Selector) {
|
||||
path := sqljson.Path(key)
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
|
||||
}
|
||||
if len(preds) == 1 {
|
||||
s.Where(preds[0])
|
||||
} else {
|
||||
s.Where(entsql.Or(preds...))
|
||||
}
|
||||
case int:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
|
||||
))
|
||||
case int64:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
|
||||
))
|
||||
case json.Number:
|
||||
if parsed, err := v.Int64(); err == nil {
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
|
||||
))
|
||||
} else {
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
|
||||
}
|
||||
default:
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
|
||||
}
|
||||
},
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user