// Package repository 实现数据访问层(Repository Pattern)。 // // 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。 // 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。 // // 主要特性: // - 使用 Ent ORM 进行类型安全的数据库操作 // - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL // - 提供统一的错误翻译机制,将数据库错误转换为业务错误 // - 支持软删除,所有查询自动过滤已删除记录 package repository import ( "context" "database/sql" "encoding/json" "errors" "log" "strconv" "time" dbent "github.com/Wei-Shaw/sub2api/ent" dbaccount "github.com/Wei-Shaw/sub2api/ent/account" dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqljson" ) // accountRepository 实现 service.AccountRepository 接口。 // 提供 AI API 账户的完整数据访问功能。 // // 设计说明: // - client: Ent 客户端,用于类型安全的 ORM 操作 // - sql: 原生 SQL 执行器,用于复杂查询和批量操作 // - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照 type accountRepository struct { client *dbent.Client // Ent ORM 客户端 sql sqlExecutor // 原生 SQL 执行接口 // schedulerCache 用于在账号状态变更时主动同步快照到缓存, // 确保粘性会话能及时感知账号不可用状态。 // Used to proactively sync account snapshot to cache when status changes, // ensuring sticky sessions can promptly detect unavailable accounts. schedulerCache service.SchedulerCache } type tempUnschedSnapshot struct { until *time.Time reason string } // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache) } // newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。 // 这种设计便于单元测试时注入 mock 对象。 func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository { return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache} } func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { if account == nil { return service.ErrAccountNilInput } builder := r.client.Account.Create(). SetName(account.Name). SetNillableNotes(account.Notes). SetPlatform(account.Platform). SetType(account.Type). SetCredentials(normalizeJSONMap(account.Credentials)). SetExtra(normalizeJSONMap(account.Extra)). SetConcurrency(account.Concurrency). SetPriority(account.Priority). SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). SetSchedulable(account.Schedulable). SetAutoPauseOnExpired(account.AutoPauseOnExpired) if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) } if account.LastUsedAt != nil { builder.SetLastUsedAt(*account.LastUsedAt) } if account.ExpiresAt != nil { builder.SetExpiresAt(*account.ExpiresAt) } if account.RateLimitedAt != nil { builder.SetRateLimitedAt(*account.RateLimitedAt) } if account.RateLimitResetAt != nil { builder.SetRateLimitResetAt(*account.RateLimitResetAt) } if account.OverloadUntil != nil { builder.SetOverloadUntil(*account.OverloadUntil) } if account.SessionWindowStart != nil { builder.SetSessionWindowStart(*account.SessionWindowStart) } if account.SessionWindowEnd != nil { builder.SetSessionWindowEnd(*account.SessionWindowEnd) } if account.SessionWindowStatus != "" { builder.SetSessionWindowStatus(account.SessionWindowStatus) } created, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrAccountNotFound, nil) } account.ID = created.ID account.CreatedAt = created.CreatedAt account.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) } return nil } func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) { m, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx) if err != nil { return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) } accounts, err := r.accountsToService(ctx, []*dbent.Account{m}) if err != nil { return nil, err } if len(accounts) == 0 { return nil, service.ErrAccountNotFound } return &accounts[0], nil } func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { if len(ids) == 0 { return []*service.Account{}, nil } // De-duplicate while preserving order of first occurrence. uniqueIDs := make([]int64, 0, len(ids)) seen := make(map[int64]struct{}, len(ids)) for _, id := range ids { if id <= 0 { continue } if _, ok := seen[id]; ok { continue } seen[id] = struct{}{} uniqueIDs = append(uniqueIDs, id) } if len(uniqueIDs) == 0 { return []*service.Account{}, nil } entAccounts, err := r.client.Account. Query(). Where(dbaccount.IDIn(uniqueIDs...)). WithProxy(). All(ctx) if err != nil { return nil, err } if len(entAccounts) == 0 { return []*service.Account{}, nil } accountIDs := make([]int64, 0, len(entAccounts)) entByID := make(map[int64]*dbent.Account, len(entAccounts)) for _, acc := range entAccounts { entByID[acc.ID] = acc accountIDs = append(accountIDs, acc.ID) } tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) if err != nil { return nil, err } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err } outByID := make(map[int64]*service.Account, len(entAccounts)) for _, entAcc := range entAccounts { out := accountEntityToService(entAcc) if out == nil { continue } // Prefer the preloaded proxy edge when available. if entAcc.Edges.Proxy != nil { out.Proxy = proxyEntityToService(entAcc.Edges.Proxy) } if groups, ok := groupsByAccount[entAcc.ID]; ok { out.Groups = groups } if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok { out.GroupIDs = groupIDs } if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { out.AccountGroups = ags } if snap, ok := tempUnschedMap[entAcc.ID]; ok { out.TempUnschedulableUntil = snap.until out.TempUnschedulableReason = snap.reason } outByID[entAcc.ID] = out } // Preserve input order (first occurrence), and ignore missing IDs. out := make([]*service.Account, 0, len(uniqueIDs)) for _, id := range uniqueIDs { if _, ok := entByID[id]; !ok { continue } if acc, ok := outByID[id]; ok && acc != nil { out = append(out, acc) } } return out, nil } // ExistsByID 检查指定 ID 的账号是否存在。 // 相比 GetByID,此方法性能更优,因为: // - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值 // - 不加载完整的账号实体及其关联数据(Groups、Proxy 等) // - 适用于删除前的存在性检查等只需判断有无的场景 func (r *accountRepository) ExistsByID(ctx context.Context, id int64) (bool, error) { exists, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Exist(ctx) if err != nil { return false, err } return exists, nil } func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { if crsAccountID == "" { return nil, nil } // 使用 sqljson.ValueEQ 生成 JSON 路径过滤,避免手写 SQL 片段导致语法兼容问题。 m, err := r.client.Account.Query(). Where(func(s *entsql.Selector) { s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id"))) }). Only(ctx) if err != nil { if dbent.IsNotFound(err) { return nil, nil } return nil, err } accounts, err := r.accountsToService(ctx, []*dbent.Account{m}) if err != nil { return nil, err } if len(accounts) == 0 { return nil, nil } return &accounts[0], nil } func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { if account == nil { return nil } builder := r.client.Account.UpdateOneID(account.ID). SetName(account.Name). SetNillableNotes(account.Notes). SetPlatform(account.Platform). SetType(account.Type). SetCredentials(normalizeJSONMap(account.Credentials)). SetExtra(normalizeJSONMap(account.Extra)). SetConcurrency(account.Concurrency). SetPriority(account.Priority). SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). SetSchedulable(account.Schedulable). SetAutoPauseOnExpired(account.AutoPauseOnExpired) if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) } else { builder.ClearProxyID() } if account.LastUsedAt != nil { builder.SetLastUsedAt(*account.LastUsedAt) } else { builder.ClearLastUsedAt() } if account.ExpiresAt != nil { builder.SetExpiresAt(*account.ExpiresAt) } else { builder.ClearExpiresAt() } if account.RateLimitedAt != nil { builder.SetRateLimitedAt(*account.RateLimitedAt) } else { builder.ClearRateLimitedAt() } if account.RateLimitResetAt != nil { builder.SetRateLimitResetAt(*account.RateLimitResetAt) } else { builder.ClearRateLimitResetAt() } if account.OverloadUntil != nil { builder.SetOverloadUntil(*account.OverloadUntil) } else { builder.ClearOverloadUntil() } if account.SessionWindowStart != nil { builder.SetSessionWindowStart(*account.SessionWindowStart) } else { builder.ClearSessionWindowStart() } if account.SessionWindowEnd != nil { builder.SetSessionWindowEnd(*account.SessionWindowEnd) } else { builder.ClearSessionWindowEnd() } if account.SessionWindowStatus != "" { builder.SetSessionWindowStatus(account.SessionWindowStatus) } else { builder.ClearSessionWindowStatus() } if account.Notes == nil { builder.ClearNotes() } updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrAccountNotFound, nil) } account.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { r.syncSchedulerAccountSnapshot(ctx, account.ID) } return nil } func (r *accountRepository) Delete(ctx context.Context, id int64) error { groupIDs, err := r.loadAccountGroupIDs(ctx, id) if err != nil { return err } // 使用事务保证账号与关联分组的删除原子性 tx, err := r.client.Tx(ctx) if err != nil && !errors.Is(err, dbent.ErrTxStarted) { return err } var txClient *dbent.Client if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() } else { // 已处于外部事务中(ErrTxStarted),复用当前 client txClient = r.client } if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { return err } if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil { return err } if tx != nil { if err := tx.Commit(); err != nil { return err } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "", "") } func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { q = q.Where(dbaccount.PlatformEQ(platform)) } if accountType != "" { q = q.Where(dbaccount.TypeEQ(accountType)) } if status != "" { q = q.Where(dbaccount.StatusEQ(status)) } if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } total, err := q.Count(ctx) if err != nil { return nil, nil, err } accounts, err := q. Offset(params.Offset()). Limit(params.Limit()). Order(dbent.Desc(dbaccount.FieldID)). All(ctx) if err != nil { return nil, nil, err } outAccounts, err := r.accountsToService(ctx, accounts) if err != nil { return nil, nil, err } return outAccounts, paginationResultFromTotal(int64(total), params), nil } func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, }) if err != nil { return nil, err } return accounts, nil } func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where(dbaccount.StatusEQ(service.StatusActive)). Order(dbent.Asc(dbaccount.FieldPriority)). All(ctx) if err != nil { return nil, err } return r.accountsToService(ctx, accounts) } func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where( dbaccount.PlatformEQ(platform), dbaccount.StatusEQ(service.StatusActive), ). Order(dbent.Asc(dbaccount.FieldPriority)). All(ctx) if err != nil { return nil, err } return r.accountsToService(ctx, accounts) } func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { now := time.Now() _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetLastUsedAt(now). Save(ctx) if err != nil { return err } payload := map[string]any{ "last_used": map[string]int64{ strconv.FormatInt(id, 10): now.Unix(), }, } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { if len(updates) == 0 { return nil } ids := make([]int64, 0, len(updates)) args := make([]any, 0, len(updates)*2+1) caseSQL := "UPDATE accounts SET last_used_at = CASE id" idx := 1 for id, ts := range updates { caseSQL += " WHEN $" + itoa(idx) + " THEN $" + itoa(idx+1) + "::timestamptz" args = append(args, id, ts) ids = append(ids, id) idx += 2 } caseSQL += " END, updated_at = NOW() WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL" args = append(args, pq.Array(ids)) _, err := r.sql.ExecContext(ctx, caseSQL, args...) if err != nil { return err } lastUsedPayload := make(map[string]int64, len(updates)) for id, ts := range updates { lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix() } payload := map[string]any{"last_used": lastUsedPayload} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err) } return nil } func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetStatus(service.StatusError). SetErrorMessage(errorMsg). Save(ctx) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil } // syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。 // 当账号被设置为错误、禁用、不可调度或临时不可调度时调用, // 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。 // // syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache // when account status changes. Called when account is set to error, disabled, // unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session // logic can promptly detect the latest account state and avoid using unavailable accounts. func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) { if r == nil || r.schedulerCache == nil || accountID <= 0 { return } account, err := r.GetByID(ctx, accountID) if err != nil { log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) return } if err := r.schedulerCache.SetAccount(ctx, account); err != nil { log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) } } func (r *accountRepository) ClearError(ctx context.Context, id int64) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetStatus(service.StatusActive). SetErrorMessage(""). Save(ctx) return err } func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { _, err := r.client.AccountGroup.Create(). SetAccountID(accountID). SetGroupID(groupID). SetPriority(priority). Save(ctx) if err != nil { return err } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { _, err := r.client.AccountGroup.Delete(). Where( dbaccountgroup.AccountIDEQ(accountID), dbaccountgroup.GroupIDEQ(groupID), ). Exec(ctx) if err != nil { return err } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where( dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID)), ). All(ctx) if err != nil { return nil, err } outGroups := make([]service.Group, 0, len(groups)) for i := range groups { outGroups = append(outGroups, *groupEntityToService(groups[i])) } return outGroups, nil } func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID) if err != nil { return err } // 使用事务保证删除旧绑定与创建新绑定的原子性 tx, err := r.client.Tx(ctx) if err != nil && !errors.Is(err, dbent.ErrTxStarted) { return err } var txClient *dbent.Client if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() } else { // 已处于外部事务中(ErrTxStarted),复用当前 client txClient = r.client } if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil { return err } if len(groupIDs) == 0 { if tx != nil { return tx.Commit() } return nil } builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs)) for i, groupID := range groupIDs { builders = append(builders, txClient.AccountGroup.Create(). SetAccountID(accountID). SetGroupID(groupID). SetPriority(i+1), ) } if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil { return err } if tx != nil { if err := tx.Commit(); err != nil { return err } } payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs)) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) } return nil } func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) { now := time.Now() accounts, err := r.client.Account.Query(). Where( dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). Order(dbent.Asc(dbaccount.FieldPriority)). All(ctx) if err != nil { return nil, err } return r.accountsToService(ctx, accounts) } func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, }) } func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { now := time.Now() accounts, err := r.client.Account.Query(). Where( dbaccount.PlatformEQ(platform), dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). Order(dbent.Asc(dbaccount.FieldPriority)). All(ctx) if err != nil { return nil, err } return r.accountsToService(ctx, accounts) } func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { // 单平台查询复用多平台逻辑,保持过滤条件与排序策略一致。 return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, platforms: []string{platform}, }) } func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil } // 仅返回可调度的活跃账号,并过滤处于过载/限流窗口的账号。 // 代理与分组信息统一在 accountsToService 中批量加载,避免 N+1 查询。 now := time.Now() accounts, err := r.client.Account.Query(). Where( dbaccount.PlatformIn(platforms...), dbaccount.StatusEQ(service.StatusActive), dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ). Order(dbent.Asc(dbaccount.FieldPriority)). All(ctx) if err != nil { return nil, err } return r.accountsToService(ctx, accounts) } func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil } // 复用按分组查询逻辑,保证分组优先级 + 账号优先级的排序与筛选一致。 return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, schedulable: true, platforms: platforms, }) } func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { now := time.Now() _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetRateLimitedAt(now). SetRateLimitResetAt(resetAt). Save(ctx) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { now := time.Now().UTC() payload := map[string]string{ "rate_limited_at": now.Format(time.RFC3339), "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), } raw, err := json.Marshal(payload) if err != nil { return err } scopeKey := string(scope) client := clientFromContext(ctx, r.client) result, err := client.ExecContext( ctx, `UPDATE accounts SET extra = jsonb_set( jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true), ARRAY['antigravity_quota_scopes', $1]::text[], $2::jsonb, true ), updated_at = NOW(), last_used_at = NOW() WHERE id = $3 AND deleted_at IS NULL`, scopeKey, raw, id, ) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected == 0 { return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { if scope == "" { return nil } now := time.Now().UTC() payload := map[string]string{ "rate_limited_at": now.Format(time.RFC3339), "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), } raw, err := json.Marshal(payload) if err != nil { return err } client := clientFromContext(ctx, r.client) result, err := client.ExecContext( ctx, `UPDATE accounts SET extra = jsonb_set( jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true), ARRAY['model_rate_limits', $1]::text[], $2::jsonb, true ), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL`, scope, raw, id, ) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected == 0 { return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetOverloadUntil(until). Save(ctx) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { _, err := r.sql.ExecContext(ctx, ` UPDATE accounts SET temp_unschedulable_until = $1, temp_unschedulable_reason = $2, updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1) `, until, reason, id) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil } func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error { _, err := r.sql.ExecContext(ctx, ` UPDATE accounts SET temp_unschedulable_until = NULL, temp_unschedulable_reason = NULL, updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL `, id) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). ClearRateLimitedAt(). ClearRateLimitResetAt(). ClearOverloadUntil(). Save(ctx) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { client := clientFromContext(ctx, r.client) result, err := client.ExecContext( ctx, "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL", id, ) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected == 0 { return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error { client := clientFromContext(ctx, r.client) result, err := client.ExecContext( ctx, "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL", id, ) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected == 0 { return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { builder := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetSessionWindowStatus(status) if start != nil { builder.SetSessionWindowStart(*start) } if end != nil { builder.SetSessionWindowEnd(*end) } _, err := builder.Save(ctx) if err != nil { return err } // 触发调度器缓存更新(仅当窗口时间有变化时) if start != nil || end != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) } } return nil } func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). SetSchedulable(schedulable). Save(ctx) if err != nil { return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) } if !schedulable { r.syncSchedulerAccountSnapshot(ctx, id) } return nil } func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { result, err := r.sql.ExecContext(ctx, ` UPDATE accounts SET schedulable = FALSE, updated_at = NOW() WHERE deleted_at IS NULL AND schedulable = TRUE AND auto_pause_on_expired = TRUE AND expires_at IS NOT NULL AND expires_at <= $1 `, now) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } if rows > 0 { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) } } return rows, nil } func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { if len(updates) == 0 { return nil } // 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题 payload, err := json.Marshal(updates) if err != nil { return err } client := clientFromContext(ctx, r.client) result, err := client.ExecContext( ctx, "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL", string(payload), id, ) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected == 0 { return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } return nil } func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil } setClauses := make([]string, 0, 8) args := make([]any, 0, 8) idx := 1 if updates.Name != nil { setClauses = append(setClauses, "name = $"+itoa(idx)) args = append(args, *updates.Name) idx++ } if updates.ProxyID != nil { // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) if *updates.ProxyID == 0 { setClauses = append(setClauses, "proxy_id = NULL") } else { setClauses = append(setClauses, "proxy_id = $"+itoa(idx)) args = append(args, *updates.ProxyID) idx++ } } if updates.Concurrency != nil { setClauses = append(setClauses, "concurrency = $"+itoa(idx)) args = append(args, *updates.Concurrency) idx++ } if updates.Priority != nil { setClauses = append(setClauses, "priority = $"+itoa(idx)) args = append(args, *updates.Priority) idx++ } if updates.RateMultiplier != nil { setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx)) args = append(args, *updates.RateMultiplier) idx++ } if updates.Status != nil { setClauses = append(setClauses, "status = $"+itoa(idx)) args = append(args, *updates.Status) idx++ } if updates.Schedulable != nil { setClauses = append(setClauses, "schedulable = $"+itoa(idx)) args = append(args, *updates.Schedulable) idx++ } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) if err != nil { return 0, err } setClauses = append(setClauses, "credentials = COALESCE(credentials, '{}'::jsonb) || $"+itoa(idx)+"::jsonb") args = append(args, payload) idx++ } if len(updates.Extra) > 0 { payload, err := json.Marshal(updates.Extra) if err != nil { return 0, err } setClauses = append(setClauses, "extra = COALESCE(extra, '{}'::jsonb) || $"+itoa(idx)+"::jsonb") args = append(args, payload) idx++ } if len(setClauses) == 0 { return 0, nil } setClauses = append(setClauses, "updated_at = NOW()") query := "UPDATE accounts SET " + joinClauses(setClauses, ", ") + " WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL" args = append(args, pq.Array(ids)) result, err := r.sql.ExecContext(ctx, query, args...) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } if rows > 0 { payload := map[string]any{"account_ids": ids} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) } shouldSync := false if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { shouldSync = true } if updates.Schedulable != nil && !*updates.Schedulable { shouldSync = true } if shouldSync { for _, id := range ids { r.syncSchedulerAccountSnapshot(ctx, id) } } } return rows, nil } type accountGroupQueryOptions struct { status string schedulable bool platforms []string // 允许的多个平台,空切片表示不进行平台过滤 } func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) { q := r.client.AccountGroup.Query(). Where(dbaccountgroup.GroupIDEQ(groupID)) // 通过 account_groups 中间表查询账号,并按需叠加状态/平台/调度能力过滤。 preds := make([]dbpredicate.Account, 0, 6) preds = append(preds, dbaccount.DeletedAtIsNil()) if opts.status != "" { preds = append(preds, dbaccount.StatusEQ(opts.status)) } if len(opts.platforms) > 0 { preds = append(preds, dbaccount.PlatformIn(opts.platforms...)) } if opts.schedulable { now := time.Now() preds = append(preds, dbaccount.SchedulableEQ(true), tempUnschedulablePredicate(), notExpiredPredicate(now), dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), ) } if len(preds) > 0 { q = q.Where(dbaccountgroup.HasAccountWith(preds...)) } groups, err := q. Order( dbaccountgroup.ByPriority(), dbaccountgroup.ByAccountField(dbaccount.FieldPriority), ). WithAccount(). All(ctx) if err != nil { return nil, err } orderedIDs := make([]int64, 0, len(groups)) accountMap := make(map[int64]*dbent.Account, len(groups)) for _, ag := range groups { if ag.Edges.Account == nil { continue } if _, exists := accountMap[ag.AccountID]; exists { continue } accountMap[ag.AccountID] = ag.Edges.Account orderedIDs = append(orderedIDs, ag.AccountID) } accounts := make([]*dbent.Account, 0, len(orderedIDs)) for _, id := range orderedIDs { if acc, ok := accountMap[id]; ok { accounts = append(accounts, acc) } } return r.accountsToService(ctx, accounts) } func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) { if len(accounts) == 0 { return []service.Account{}, nil } accountIDs := make([]int64, 0, len(accounts)) proxyIDs := make([]int64, 0, len(accounts)) for _, acc := range accounts { accountIDs = append(accountIDs, acc.ID) if acc.ProxyID != nil { proxyIDs = append(proxyIDs, *acc.ProxyID) } } proxyMap, err := r.loadProxies(ctx, proxyIDs) if err != nil { return nil, err } tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) if err != nil { return nil, err } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err } outAccounts := make([]service.Account, 0, len(accounts)) for _, acc := range accounts { out := accountEntityToService(acc) if out == nil { continue } if acc.ProxyID != nil { if proxy, ok := proxyMap[*acc.ProxyID]; ok { out.Proxy = proxy } } if groups, ok := groupsByAccount[acc.ID]; ok { out.Groups = groups } if groupIDs, ok := groupIDsByAccount[acc.ID]; ok { out.GroupIDs = groupIDs } if ags, ok := accountGroupsByAccount[acc.ID]; ok { out.AccountGroups = ags } if snap, ok := tempUnschedMap[acc.ID]; ok { out.TempUnschedulableUntil = snap.until out.TempUnschedulableReason = snap.reason } outAccounts = append(outAccounts, *out) } return outAccounts, nil } func tempUnschedulablePredicate() dbpredicate.Account { return dbpredicate.Account(func(s *entsql.Selector) { col := s.C("temp_unschedulable_until") s.Where(entsql.Or( entsql.IsNull(col), entsql.LTE(col, entsql.Expr("NOW()")), )) }) } func notExpiredPredicate(now time.Time) dbpredicate.Account { return dbaccount.Or( dbaccount.ExpiresAtIsNil(), dbaccount.ExpiresAtGT(now), dbaccount.AutoPauseOnExpiredEQ(false), ) } func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { out := make(map[int64]tempUnschedSnapshot) if len(accountIDs) == 0 { return out, nil } rows, err := r.sql.QueryContext(ctx, ` SELECT id, temp_unschedulable_until, temp_unschedulable_reason FROM accounts WHERE id = ANY($1) `, pq.Array(accountIDs)) if err != nil { return nil, err } defer func() { _ = rows.Close() }() for rows.Next() { var id int64 var until sql.NullTime var reason sql.NullString if err := rows.Scan(&id, &until, &reason); err != nil { return nil, err } var untilPtr *time.Time if until.Valid { tmp := until.Time untilPtr = &tmp } if reason.Valid { out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String} } else { out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""} } } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { proxyMap := make(map[int64]*service.Proxy) if len(proxyIDs) == 0 { return proxyMap, nil } proxies, err := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...)).All(ctx) if err != nil { return nil, err } for _, p := range proxies { proxyMap[p.ID] = proxyEntityToService(p) } return proxyMap, nil } func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) { groupsByAccount := make(map[int64][]*service.Group) groupIDsByAccount := make(map[int64][]int64) accountGroupsByAccount := make(map[int64][]service.AccountGroup) if len(accountIDs) == 0 { return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil } entries, err := r.client.AccountGroup.Query(). Where(dbaccountgroup.AccountIDIn(accountIDs...)). WithGroup(). Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()). All(ctx) if err != nil { return nil, nil, nil, err } for _, ag := range entries { groupSvc := groupEntityToService(ag.Edges.Group) agSvc := service.AccountGroup{ AccountID: ag.AccountID, GroupID: ag.GroupID, Priority: ag.Priority, CreatedAt: ag.CreatedAt, Group: groupSvc, } accountGroupsByAccount[ag.AccountID] = append(accountGroupsByAccount[ag.AccountID], agSvc) groupIDsByAccount[ag.AccountID] = append(groupIDsByAccount[ag.AccountID], ag.GroupID) if groupSvc != nil { groupsByAccount[ag.AccountID] = append(groupsByAccount[ag.AccountID], groupSvc) } } return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil } func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) { entries, err := r.client.AccountGroup. Query(). Where(dbaccountgroup.AccountIDEQ(accountID)). All(ctx) if err != nil { return nil, err } ids := make([]int64, 0, len(entries)) for _, entry := range entries { ids = append(ids, entry.GroupID) } return ids, nil } func mergeGroupIDs(a []int64, b []int64) []int64 { seen := make(map[int64]struct{}, len(a)+len(b)) out := make([]int64, 0, len(a)+len(b)) for _, id := range a { if id <= 0 { continue } if _, ok := seen[id]; ok { continue } seen[id] = struct{}{} out = append(out, id) } for _, id := range b { if id <= 0 { continue } if _, ok := seen[id]; ok { continue } seen[id] = struct{}{} out = append(out, id) } return out } func buildSchedulerGroupPayload(groupIDs []int64) map[string]any { if len(groupIDs) == 0 { return nil } return map[string]any{"group_ids": groupIDs} } func accountEntityToService(m *dbent.Account) *service.Account { if m == nil { return nil } rateMultiplier := m.RateMultiplier return &service.Account{ ID: m.ID, Name: m.Name, Notes: m.Notes, Platform: m.Platform, Type: m.Type, Credentials: copyJSONMap(m.Credentials), Extra: copyJSONMap(m.Extra), ProxyID: m.ProxyID, Concurrency: m.Concurrency, Priority: m.Priority, RateMultiplier: &rateMultiplier, Status: m.Status, ErrorMessage: derefString(m.ErrorMessage), LastUsedAt: m.LastUsedAt, ExpiresAt: m.ExpiresAt, AutoPauseOnExpired: m.AutoPauseOnExpired, CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, Schedulable: m.Schedulable, RateLimitedAt: m.RateLimitedAt, RateLimitResetAt: m.RateLimitResetAt, OverloadUntil: m.OverloadUntil, SessionWindowStart: m.SessionWindowStart, SessionWindowEnd: m.SessionWindowEnd, SessionWindowStatus: derefString(m.SessionWindowStatus), } } func normalizeJSONMap(in map[string]any) map[string]any { if in == nil { return map[string]any{} } return in } func copyJSONMap(in map[string]any) map[string]any { if in == nil { return nil } out := make(map[string]any, len(in)) for k, v := range in { out[k] = v } return out } func joinClauses(clauses []string, sep string) string { if len(clauses) == 0 { return "" } out := clauses[0] for i := 1; i < len(clauses); i++ { out += sep + clauses[i] } return out } func itoa(v int) string { return strconv.Itoa(v) }