Files
yinghuoapi/backend/internal/repository/account_repo.go
墨颜 fb99ceacc7 feat(计费): 支持账号计费倍率快照与统计展示
- 新增 accounts.rate_multiplier(默认 1.0,允许 0)
- 使用 usage_logs.account_rate_multiplier 记录倍率快照,避免历史回算
- 统计/导出/管理端展示账号口径费用(total_cost * account_rate_multiplier)
2026-01-14 16:12:08 +08:00

1426 lines
41 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package repository 实现数据访问层(Repository Pattern)。
//
// 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。
// 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。
//
// 主要特性:
// - 使用 Ent ORM 进行类型安全的数据库操作
// - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL
// - 提供统一的错误翻译机制,将数据库错误转换为业务错误
// - 支持软删除,所有查询自动过滤已删除记录
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
)
// accountRepository implements the service.AccountRepository interface and
// provides data access for AI API accounts.
//
// Design notes:
//   - client: Ent client, used for type-safe ORM operations.
//   - sql: raw SQL executor, used for complex queries and batch operations.
type accountRepository struct {
client *dbent.Client // Ent ORM client
sql sqlExecutor // raw SQL execution interface
}
// tempUnschedSnapshot holds the temporary-unschedulable state loaded for one
// account; its fields are copied onto service.Account when entities are
// converted to service models.
type tempUnschedSnapshot struct {
until *time.Time // copied to Account.TempUnschedulableUntil
reason string // copied to Account.TempUnschedulableReason
}
// NewAccountRepository builds the account repository.
// It is the exported constructor and returns the interface type so the
// repository can be wired in through dependency injection.
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
	repo := newAccountRepositoryWithSQL(client, sqlDB)
	return repo
}
// newAccountRepositoryWithSQL is the internal constructor; it accepts any
// sqlExecutor so that unit tests can inject a mock SQL executor.
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
	repo := new(accountRepository)
	repo.client = client
	repo.sql = sqlq
	return repo
}
// Create inserts a new account row built from account and copies the generated
// ID, CreatedAt and UpdatedAt back onto it.
//
// Pointer fields (RateMultiplier, ProxyID and the timestamp fields) are written
// only when non-nil, and SessionWindowStatus only when non-empty. A nil account
// yields service.ErrAccountNilInput; save errors are passed through
// translatePersistenceError. A failed scheduler-outbox enqueue is only logged
// and does not fail the call.
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
	if account == nil {
		return service.ErrAccountNilInput
	}

	create := r.client.Account.Create()

	// Columns that are always written.
	create.SetName(account.Name)
	create.SetNillableNotes(account.Notes)
	create.SetPlatform(account.Platform)
	create.SetType(account.Type)
	create.SetCredentials(normalizeJSONMap(account.Credentials))
	create.SetExtra(normalizeJSONMap(account.Extra))
	create.SetConcurrency(account.Concurrency)
	create.SetPriority(account.Priority)
	create.SetStatus(account.Status)
	create.SetErrorMessage(account.ErrorMessage)
	create.SetSchedulable(account.Schedulable)
	create.SetAutoPauseOnExpired(account.AutoPauseOnExpired)

	// Columns that are written only when the caller supplied a value.
	if v := account.RateMultiplier; v != nil {
		create.SetRateMultiplier(*v)
	}
	if v := account.ProxyID; v != nil {
		create.SetProxyID(*v)
	}
	if v := account.LastUsedAt; v != nil {
		create.SetLastUsedAt(*v)
	}
	if v := account.ExpiresAt; v != nil {
		create.SetExpiresAt(*v)
	}
	if v := account.RateLimitedAt; v != nil {
		create.SetRateLimitedAt(*v)
	}
	if v := account.RateLimitResetAt; v != nil {
		create.SetRateLimitResetAt(*v)
	}
	if v := account.OverloadUntil; v != nil {
		create.SetOverloadUntil(*v)
	}
	if v := account.SessionWindowStart; v != nil {
		create.SetSessionWindowStart(*v)
	}
	if v := account.SessionWindowEnd; v != nil {
		create.SetSessionWindowEnd(*v)
	}
	if status := account.SessionWindowStatus; status != "" {
		create.SetSessionWindowStatus(status)
	}

	saved, err := create.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}

	// Reflect database-generated values on the caller's struct.
	account.ID = saved.ID
	account.CreatedAt = saved.CreatedAt
	account.UpdatedAt = saved.UpdatedAt

	groupPayload := buildSchedulerGroupPayload(account.GroupIDs)
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, groupPayload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
	}
	return nil
}
// GetByID loads a single account by primary key and converts it to the service
// model via accountsToService.
//
// Query errors are passed through translatePersistenceError with
// service.ErrAccountNotFound as the not-found error; an empty conversion
// result is also reported as service.ErrAccountNotFound.
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	row, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}

	converted, err := r.accountsToService(ctx, []*dbent.Account{row})
	switch {
	case err != nil:
		return nil, err
	case len(converted) == 0:
		return nil, service.ErrAccountNotFound
	}
	return &converted[0], nil
}
// GetByIDs loads the accounts with the given IDs, together with their proxy,
// group bindings and temporary-unschedulable state.
//
// Non-positive and duplicate IDs are dropped. The result follows the order of
// first occurrence in ids, and IDs that match no row are silently omitted. An
// empty (non-nil) slice is returned when nothing is requested or found.
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
	if len(ids) == 0 {
		return []*service.Account{}, nil
	}

	// Keep the first occurrence of every valid ID, in input order.
	wanted := make([]int64, 0, len(ids))
	visited := make(map[int64]struct{}, len(ids))
	for _, id := range ids {
		if id <= 0 {
			continue
		}
		if _, dup := visited[id]; !dup {
			visited[id] = struct{}{}
			wanted = append(wanted, id)
		}
	}
	if len(wanted) == 0 {
		return []*service.Account{}, nil
	}

	rows, err := r.client.Account.Query().
		Where(dbaccount.IDIn(wanted...)).
		WithProxy().
		All(ctx)
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return []*service.Account{}, nil
	}

	foundIDs := make([]int64, len(rows))
	for i, row := range rows {
		foundIDs[i] = row.ID
	}

	tempUnsched, err := r.loadTempUnschedStates(ctx, foundIDs)
	if err != nil {
		return nil, err
	}
	groups, groupIDs, accountGroups, err := r.loadAccountGroups(ctx, foundIDs)
	if err != nil {
		return nil, err
	}

	converted := make(map[int64]*service.Account, len(rows))
	for _, row := range rows {
		acc := accountEntityToService(row)
		if acc == nil {
			continue
		}
		// Prefer the proxy edge that was eager-loaded with the account.
		if proxy := row.Edges.Proxy; proxy != nil {
			acc.Proxy = proxyEntityToService(proxy)
		}
		if v, ok := groups[row.ID]; ok {
			acc.Groups = v
		}
		if v, ok := groupIDs[row.ID]; ok {
			acc.GroupIDs = v
		}
		if v, ok := accountGroups[row.ID]; ok {
			acc.AccountGroups = v
		}
		if snap, ok := tempUnsched[row.ID]; ok {
			acc.TempUnschedulableUntil = snap.until
			acc.TempUnschedulableReason = snap.reason
		}
		converted[row.ID] = acc
	}

	// Emit in request order; IDs without a converted row are skipped.
	result := make([]*service.Account, 0, len(wanted))
	for _, id := range wanted {
		if acc := converted[id]; acc != nil {
			result = append(result, acc)
		}
	}
	return result, nil
}
// ExistsByID reports whether an account with the given ID exists.
//
// It is cheaper than GetByID because it issues an existence query through
// Exist() and does not load the account entity or its related data (groups,
// proxy, ...). Use it where only presence matters, e.g. before a delete.
func (r *accountRepository) ExistsByID(ctx context.Context, id int64) (bool, error) {
	found, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Exist(ctx)
	if err == nil {
		return found, nil
	}
	return false, err
}
// GetByCRSAccountID looks up the account whose extra JSON has
// "crs_account_id" equal to crsAccountID.
//
// It returns (nil, nil) when crsAccountID is empty or when no account matches;
// any other query error is returned unchanged.
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
	if len(crsAccountID) == 0 {
		return nil, nil
	}

	// sqljson.ValueEQ generates the JSON path filter, which avoids the syntax
	// compatibility problems of a hand-written SQL fragment.
	byCRSID := func(s *entsql.Selector) {
		s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, crsAccountID, sqljson.Path("crs_account_id")))
	}

	row, err := r.client.Account.Query().Where(byCRSID).Only(ctx)
	switch {
	case err == nil:
		// Found exactly one row; convert it below.
	case dbent.IsNotFound(err):
		return nil, nil
	default:
		return nil, err
	}

	converted, err := r.accountsToService(ctx, []*dbent.Account{row})
	if err != nil {
		return nil, err
	}
	if len(converted) == 0 {
		return nil, nil
	}
	return &converted[0], nil
}
// Update overwrites the stored account identified by account.ID with the values
// in account and copies the new UpdatedAt back onto it.
//
// Nullable columns are set when the corresponding pointer is non-nil and
// cleared otherwise; SessionWindowStatus is cleared when empty and Notes is
// cleared when nil. RateMultiplier is the exception: a nil value leaves the
// stored multiplier untouched.
//
// A nil account is rejected with service.ErrAccountNilInput, matching Create,
// instead of being reported as a successful no-op. Save errors are passed
// through translatePersistenceError. A failed scheduler-outbox enqueue is only
// logged and does not fail the call.
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
	if account == nil {
		return service.ErrAccountNilInput
	}

	builder := r.client.Account.UpdateOneID(account.ID).
		SetName(account.Name).
		SetNillableNotes(account.Notes).
		SetPlatform(account.Platform).
		SetType(account.Type).
		SetCredentials(normalizeJSONMap(account.Credentials)).
		SetExtra(normalizeJSONMap(account.Extra)).
		SetConcurrency(account.Concurrency).
		SetPriority(account.Priority).
		SetStatus(account.Status).
		SetErrorMessage(account.ErrorMessage).
		SetSchedulable(account.Schedulable).
		SetAutoPauseOnExpired(account.AutoPauseOnExpired)

	// A nil multiplier means "keep the stored value"; there is no clear path.
	if account.RateMultiplier != nil {
		builder.SetRateMultiplier(*account.RateMultiplier)
	}

	// For the remaining nullable columns, nil means "clear the column".
	if account.ProxyID != nil {
		builder.SetProxyID(*account.ProxyID)
	} else {
		builder.ClearProxyID()
	}
	if account.LastUsedAt != nil {
		builder.SetLastUsedAt(*account.LastUsedAt)
	} else {
		builder.ClearLastUsedAt()
	}
	if account.ExpiresAt != nil {
		builder.SetExpiresAt(*account.ExpiresAt)
	} else {
		builder.ClearExpiresAt()
	}
	if account.RateLimitedAt != nil {
		builder.SetRateLimitedAt(*account.RateLimitedAt)
	} else {
		builder.ClearRateLimitedAt()
	}
	if account.RateLimitResetAt != nil {
		builder.SetRateLimitResetAt(*account.RateLimitResetAt)
	} else {
		builder.ClearRateLimitResetAt()
	}
	if account.OverloadUntil != nil {
		builder.SetOverloadUntil(*account.OverloadUntil)
	} else {
		builder.ClearOverloadUntil()
	}
	if account.SessionWindowStart != nil {
		builder.SetSessionWindowStart(*account.SessionWindowStart)
	} else {
		builder.ClearSessionWindowStart()
	}
	if account.SessionWindowEnd != nil {
		builder.SetSessionWindowEnd(*account.SessionWindowEnd)
	} else {
		builder.ClearSessionWindowEnd()
	}
	if account.SessionWindowStatus != "" {
		builder.SetSessionWindowStatus(account.SessionWindowStatus)
	} else {
		builder.ClearSessionWindowStatus()
	}
	if account.Notes == nil {
		builder.ClearNotes()
	}

	updated, err := builder.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}
	account.UpdatedAt = updated.UpdatedAt

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
		log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
	}
	return nil
}
// Delete removes the account with the given ID together with its
// account-group bindings.
//
// The group IDs are read first so they can be included in the scheduler outbox
// payload after the rows are gone. Both deletes run inside one transaction;
// when the client is already inside a transaction (dbent.ErrTxStarted) the
// current client is reused and no commit is issued here. A failed outbox
// enqueue is only logged.
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
	boundGroupIDs, err := r.loadAccountGroupIDs(ctx, id)
	if err != nil {
		return err
	}

	// Run both deletes atomically.
	tx, txErr := r.client.Tx(ctx)
	db := r.client
	switch {
	case txErr == nil:
		defer func() { _ = tx.Rollback() }()
		db = tx.Client()
	case !errors.Is(txErr, dbent.ErrTxStarted):
		return txErr
	default:
		// Already inside an outer transaction: keep using the current client.
	}

	if _, err := db.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
		return err
	}
	if _, err := db.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
		return err
	}
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	groupPayload := buildSchedulerGroupPayload(boundGroupIDs)
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, groupPayload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
	}
	return nil
}
// List returns one page of accounts without any filter; it delegates to
// ListWithFilters with every filter left empty.
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
	const noFilter = ""
	return r.ListWithFilters(ctx, params, noFilter, noFilter, noFilter, noFilter)
}
// ListWithFilters returns one page of accounts, newest ID first, plus the
// pagination result computed from the total match count.
//
// Each of platform, accountType and status is applied as an equality filter
// when non-empty; a non-empty search is matched case-insensitively against the
// account name (NameContainsFold).
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
	filters := make([]dbpredicate.Account, 0, 4)
	if platform != "" {
		filters = append(filters, dbaccount.PlatformEQ(platform))
	}
	if accountType != "" {
		filters = append(filters, dbaccount.TypeEQ(accountType))
	}
	if status != "" {
		filters = append(filters, dbaccount.StatusEQ(status))
	}
	if search != "" {
		filters = append(filters, dbaccount.NameContainsFold(search))
	}
	query := r.client.Account.Query().Where(filters...)

	// Count before paging so the total covers every matching row.
	total, err := query.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	rows, err := query.
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(dbaccount.FieldID)).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	converted, err := r.accountsToService(ctx, rows)
	if err != nil {
		return nil, nil, err
	}
	return converted, paginationResultFromTotal(int64(total), params), nil
}
// ListByGroup returns the active accounts bound to the given group, using the
// shared group query with only the status filter set.
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
	opts := accountGroupQueryOptions{status: service.StatusActive}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// ListActive returns all accounts whose status is service.StatusActive,
// ordered by ascending priority.
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
	query := r.client.Account.Query().
		Where(dbaccount.StatusEQ(service.StatusActive)).
		Order(dbent.Asc(dbaccount.FieldPriority))

	rows, err := query.All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListByPlatform returns the active accounts of the given platform, ordered by
// ascending priority.
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	filters := []dbpredicate.Account{
		dbaccount.PlatformEQ(platform),
		dbaccount.StatusEQ(service.StatusActive),
	}

	rows, err := r.client.Account.Query().
		Where(filters...).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// UpdateLastUsed sets last_used_at of the given account to the current time
// and enqueues an account-last-used scheduler outbox event carrying the same
// timestamp (Unix seconds, keyed by the decimal account ID).
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
	usedAt := time.Now()

	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetLastUsedAt(usedAt)
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	lastUsed := map[string]int64{strconv.FormatInt(id, 10): usedAt.Unix()}
	payload := map[string]any{"last_used": lastUsed}
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
	}
	return nil
}
// BatchUpdateLastUsed writes last_used_at for many accounts in one UPDATE
// statement that uses a CASE expression keyed by account ID, restricted to
// non-deleted rows.
//
// An empty updates map is a no-op. After a successful update a single
// account-last-used outbox event is enqueued with every ID/timestamp pair
// (Unix seconds); a failed enqueue is only logged.
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
	if len(updates) == 0 {
		return nil
	}

	var (
		accountIDs = make([]int64, 0, len(updates))
		sqlArgs    = make([]any, 0, len(updates)*2+1)
		lastUsed   = make(map[string]int64, len(updates))
		stmt       = "UPDATE accounts SET last_used_at = CASE id"
		next       = 1 // next free positional placeholder
	)
	for accountID, usedAt := range updates {
		// Each entry consumes two placeholders: the ID and its timestamp.
		stmt += " WHEN $" + itoa(next) + " THEN $" + itoa(next+1) + "::timestamptz"
		sqlArgs = append(sqlArgs, accountID, usedAt)
		accountIDs = append(accountIDs, accountID)
		lastUsed[strconv.FormatInt(accountID, 10)] = usedAt.Unix()
		next += 2
	}
	stmt += " END, updated_at = NOW() WHERE id = ANY($" + itoa(next) + ") AND deleted_at IS NULL"
	sqlArgs = append(sqlArgs, pq.Array(accountIDs))

	if _, err := r.sql.ExecContext(ctx, stmt, sqlArgs...); err != nil {
		return err
	}

	payload := map[string]any{"last_used": lastUsed}
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
	}
	return nil
}
// SetError marks the account as service.StatusError and stores errorMsg as its
// error message, then enqueues an account-changed outbox event.
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetStatus(service.StatusError)
	upd.SetErrorMessage(errorMsg)
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
	}
	return nil
}
// AddToGroup creates an account-group binding with the given priority and
// enqueues an account-groups-changed outbox event for that group.
//
// A failed enqueue is only logged; the insert error, if any, is returned.
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
	binding := r.client.AccountGroup.Create()
	binding.SetAccountID(accountID)
	binding.SetGroupID(groupID)
	binding.SetPriority(priority)
	if _, err := binding.Save(ctx); err != nil {
		return err
	}

	groupPayload := buildSchedulerGroupPayload([]int64{groupID})
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, groupPayload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
	}
	return nil
}
// RemoveFromGroup deletes the binding between the account and the group and
// enqueues an account-groups-changed outbox event for that group.
//
// A failed enqueue is only logged; the delete error, if any, is returned.
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
	binding := []dbpredicate.AccountGroup{
		dbaccountgroup.AccountIDEQ(accountID),
		dbaccountgroup.GroupIDEQ(groupID),
	}
	if _, err := r.client.AccountGroup.Delete().Where(binding...).Exec(ctx); err != nil {
		return err
	}

	groupPayload := buildSchedulerGroupPayload([]int64{groupID})
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, groupPayload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
	}
	return nil
}
// GetGroups returns the groups that contain the given account, converted to
// service models.
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
	containsAccount := dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID))

	rows, err := r.client.Group.Query().Where(containsAccount).All(ctx)
	if err != nil {
		return nil, err
	}

	result := make([]service.Group, len(rows))
	for i, row := range rows {
		result[i] = *groupEntityToService(row)
	}
	return result, nil
}
// BindGroups replaces all group bindings of the account with groupIDs.
//
// The previous bindings are deleted and the new ones created inside one
// transaction; binding priority follows the position in groupIDs (first = 1).
// When the client is already inside a transaction (dbent.ErrTxStarted) the
// current client is reused and no commit is issued here. An empty groupIDs
// simply removes every binding.
//
// After the change is committed, an account-groups-changed outbox event is
// enqueued covering both the previous and the new group IDs. This also happens
// when groupIDs is empty, so that the scheduler learns about the groups the
// account has just left. A failed enqueue is only logged.
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
	existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
	if err != nil {
		return err
	}

	// Use a transaction so dropping old bindings and creating new ones is atomic.
	tx, err := r.client.Tx(ctx)
	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
		return err
	}
	// Default to the current client, which is correct when an outer
	// transaction is already running (ErrTxStarted).
	txClient := r.client
	if err == nil {
		defer func() { _ = tx.Rollback() }()
		txClient = tx.Client()
	}

	if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
		return err
	}

	if len(groupIDs) > 0 {
		builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
		for i, groupID := range groupIDs {
			builders = append(builders, txClient.AccountGroup.Create().
				SetAccountID(accountID).
				SetGroupID(groupID).
				SetPriority(i+1),
			)
		}
		if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil {
			return err
		}
	}

	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	// Notify the scheduler about every group that was or now is bound.
	payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
	}
	return nil
}
// ListSchedulable returns every account that can currently be scheduled,
// ordered by ascending priority.
//
// An account qualifies when it is active, flagged schedulable, passes
// tempUnschedulablePredicate and notExpiredPredicate, and its overload_until
// and rate_limit_reset_at are each either unset or not later than now.
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
	now := time.Now()
	filters := []dbpredicate.Account{
		dbaccount.StatusEQ(service.StatusActive),
		dbaccount.SchedulableEQ(true),
		tempUnschedulablePredicate(),
		notExpiredPredicate(now),
		dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
		dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
	}

	rows, err := r.client.Account.Query().
		Where(filters...).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListSchedulableByGroupID returns the active, schedulable accounts bound to
// the given group, via the shared group query.
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
	opts := accountGroupQueryOptions{
		status:      service.StatusActive,
		schedulable: true,
	}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// ListSchedulableByPlatform returns the accounts of one platform that can
// currently be scheduled, ordered by ascending priority.
//
// Besides the platform match, the same conditions as ListSchedulable apply:
// active, flagged schedulable, tempUnschedulablePredicate,
// notExpiredPredicate, and overload/rate-limit windows unset or elapsed.
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	now := time.Now()
	filters := []dbpredicate.Account{
		dbaccount.PlatformEQ(platform),
		dbaccount.StatusEQ(service.StatusActive),
		dbaccount.SchedulableEQ(true),
		tempUnschedulablePredicate(),
		notExpiredPredicate(now),
		dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
		dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
	}

	rows, err := r.client.Account.Query().
		Where(filters...).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListSchedulableByGroupIDAndPlatform returns the active, schedulable accounts
// of one platform bound to the given group.
//
// The single-platform lookup reuses the multi-platform path so filtering and
// ordering stay identical.
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	opts := accountGroupQueryOptions{
		status:      service.StatusActive,
		schedulable: true,
		platforms:   []string{platform},
	}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// ListSchedulableByPlatforms returns the accounts of any of the given
// platforms that can currently be scheduled, ordered by ascending priority.
//
// An empty platforms slice yields (nil, nil). Only active, schedulable
// accounts outside their overload / rate-limit windows are returned. Proxy and
// group data are batch-loaded in accountsToService to avoid N+1 queries.
func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
	if len(platforms) == 0 {
		return nil, nil
	}

	now := time.Now()
	filters := []dbpredicate.Account{
		dbaccount.PlatformIn(platforms...),
		dbaccount.StatusEQ(service.StatusActive),
		dbaccount.SchedulableEQ(true),
		tempUnschedulablePredicate(),
		notExpiredPredicate(now),
		dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
		dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
	}

	rows, err := r.client.Account.Query().
		Where(filters...).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListSchedulableByGroupIDAndPlatforms returns the active, schedulable
// accounts of any of the given platforms bound to the given group.
//
// An empty platforms slice yields (nil, nil). The shared group query is
// reused so that ordering (binding priority, then account priority) and
// filtering stay consistent.
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
	if len(platforms) == 0 {
		return nil, nil
	}
	opts := accountGroupQueryOptions{
		status:      service.StatusActive,
		schedulable: true,
		platforms:   platforms,
	}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// SetRateLimited records that the account is rate limited: rate_limited_at is
// set to the current time and rate_limit_reset_at to resetAt. An
// account-changed outbox event is then enqueued.
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	limitedAt := time.Now()

	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetRateLimitedAt(limitedAt)
	upd.SetRateLimitResetAt(resetAt)
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
	}
	return nil
}
// SetAntigravityQuotaScopeLimit stores a rate-limit record for one quota scope
// inside the account's extra JSON, at extra.antigravity_quota_scopes.<scope>.
//
// The record holds "rate_limited_at" (current UTC time) and
// "rate_limit_reset_at" (resetAt in UTC), both formatted as RFC 3339. The
// write is a single jsonb_set UPDATE on non-deleted rows.
//
// The JSON path is passed as a properly encoded text[] value via pq.Array
// rather than a hand-concatenated array literal, so a scope containing array
// metacharacters (comma, braces, quotes) cannot alter the path.
//
// Returns service.ErrAccountNotFound when no row was updated. A failed
// scheduler-outbox enqueue is only logged.
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
	now := time.Now().UTC()
	payload := map[string]string{
		"rate_limited_at":     now.Format(time.RFC3339),
		"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	// Path elements for jsonb_set: top-level key, then the scope name.
	path := pq.Array([]string{"antigravity_quota_scopes", string(scope)})

	client := clientFromContext(ctx, r.client)
	result, err := client.ExecContext(
		ctx,
		"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
		path,
		raw,
		id,
	)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return service.ErrAccountNotFound
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
	}
	return nil
}
// SetOverloaded sets overload_until of the account to until and enqueues an
// account-changed outbox event.
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetOverloadUntil(until)
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
	}
	return nil
}
// SetTempUnschedulable marks the account as temporarily unschedulable until
// the given time, with the given reason.
//
// The UPDATE only touches non-deleted rows whose current
// temp_unschedulable_until is NULL or earlier than until, so an existing later
// deadline is never shortened. An account-changed outbox event is enqueued
// after the statement succeeds; a failed enqueue is only logged.
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	const stmt = `
UPDATE accounts
SET temp_unschedulable_until = $1,
temp_unschedulable_reason = $2,
updated_at = NOW()
WHERE id = $3
AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`
	if _, err := r.sql.ExecContext(ctx, stmt, until, reason, id); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
	}
	return nil
}
// ClearTempUnschedulable resets temp_unschedulable_until and
// temp_unschedulable_reason to NULL for the non-deleted account and enqueues
// an account-changed outbox event.
//
// A failed enqueue is only logged; the statement error, if any, is returned.
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
	const stmt = `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL,
updated_at = NOW()
WHERE id = $1
AND deleted_at IS NULL
`
	if _, err := r.sql.ExecContext(ctx, stmt, id); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
	}
	return nil
}
// ClearRateLimit clears rate_limited_at, rate_limit_reset_at and
// overload_until of the account and enqueues an account-changed outbox event.
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.ClearRateLimitedAt()
	upd.ClearRateLimitResetAt()
	upd.ClearOverloadUntil()
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
	}
	return nil
}
// ClearAntigravityQuotaScopes removes the "antigravity_quota_scopes" key from
// the extra JSON of the non-deleted account.
//
// Returns service.ErrAccountNotFound when no row was updated. A failed
// scheduler-outbox enqueue is only logged.
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
	const stmt = "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL"

	res, err := clientFromContext(ctx, r.client).ExecContext(ctx, stmt, id)
	if err != nil {
		return err
	}
	switch n, err := res.RowsAffected(); {
	case err != nil:
		return err
	case n == 0:
		return service.ErrAccountNotFound
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
	}
	return nil
}
// UpdateSessionWindow always writes the session window status and writes the
// start / end timestamps only when the corresponding pointer is non-nil (a nil
// pointer leaves the stored value unchanged). No scheduler outbox event is
// enqueued by this method.
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetSessionWindowStatus(status)
	if start != nil {
		upd.SetSessionWindowStart(*start)
	}
	if end != nil {
		upd.SetSessionWindowEnd(*end)
	}

	_, err := upd.Save(ctx)
	return err
}
// SetSchedulable writes the schedulable flag of the account and enqueues an
// account-changed outbox event.
//
// A failed enqueue is only logged; the update error, if any, is returned.
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
	upd := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	upd.SetSchedulable(schedulable)
	if _, err := upd.Save(ctx); err != nil {
		return err
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
	}
	return nil
}
// AutoPauseExpiredAccounts sets schedulable = FALSE on every non-deleted,
// currently schedulable account that has auto_pause_on_expired enabled and an
// expires_at at or before now. It returns the number of rows changed.
//
// When at least one row changed, a full-rebuild scheduler outbox event is
// enqueued; a failed enqueue is only logged.
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
	const stmt = `
UPDATE accounts
SET schedulable = FALSE,
updated_at = NOW()
WHERE deleted_at IS NULL
AND schedulable = TRUE
AND auto_pause_on_expired = TRUE
AND expires_at IS NOT NULL
AND expires_at <= $1
`
	res, err := r.sql.ExecContext(ctx, stmt, now)
	if err != nil {
		return 0, err
	}
	paused, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	if paused <= 0 {
		return paused, nil
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
	}
	return paused, nil
}
// UpdateExtra merges updates into the extra JSON of the non-deleted account.
//
// The merge is done in one statement with the JSONB || operator, which makes
// the update atomic and avoids the lost-update problem of read-modify-write.
// An empty updates map is a no-op. Returns service.ErrAccountNotFound when no
// row was updated. A failed scheduler-outbox enqueue is only logged.
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
	if len(updates) == 0 {
		return nil
	}

	patch, err := json.Marshal(updates)
	if err != nil {
		return err
	}

	const stmt = "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL"
	res, err := clientFromContext(ctx, r.client).ExecContext(ctx, stmt, patch, id)
	if err != nil {
		return err
	}
	switch n, err := res.RowsAffected(); {
	case err != nil:
		return err
	case n == 0:
		return service.ErrAccountNotFound
	}

	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
	}
	return nil
}
// BulkUpdate applies the non-nil fields of updates to every non-deleted
// account in ids with a single UPDATE statement and returns the number of rows
// changed.
//
// Scalar fields overwrite the column. A ProxyID of 0 clears the proxy (the
// frontend sends 0, not null, to express "clear"). Credentials and Extra are
// merged into the existing JSONB value instead of replacing it. When ids is
// empty or no field is set, nothing is executed and 0 is returned.
//
// When at least one row changed, an account-bulk-changed outbox event listing
// ids is enqueued; a failed enqueue is only logged.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}

	var (
		assignments = make([]string, 0, 8)
		sqlArgs     = make([]any, 0, 8)
		next        = 1 // next free positional placeholder
	)
	// bind registers value as the next positional argument and returns its
	// "$n" placeholder.
	bind := func(value any) string {
		placeholder := "$" + itoa(next)
		sqlArgs = append(sqlArgs, value)
		next++
		return placeholder
	}

	if updates.Name != nil {
		assignments = append(assignments, "name = "+bind(*updates.Name))
	}
	if updates.ProxyID != nil {
		if *updates.ProxyID == 0 {
			// 0 is the "clear proxy" sentinel; no argument is bound.
			assignments = append(assignments, "proxy_id = NULL")
		} else {
			assignments = append(assignments, "proxy_id = "+bind(*updates.ProxyID))
		}
	}
	if updates.Concurrency != nil {
		assignments = append(assignments, "concurrency = "+bind(*updates.Concurrency))
	}
	if updates.Priority != nil {
		assignments = append(assignments, "priority = "+bind(*updates.Priority))
	}
	if updates.RateMultiplier != nil {
		assignments = append(assignments, "rate_multiplier = "+bind(*updates.RateMultiplier))
	}
	if updates.Status != nil {
		assignments = append(assignments, "status = "+bind(*updates.Status))
	}
	if updates.Schedulable != nil {
		assignments = append(assignments, "schedulable = "+bind(*updates.Schedulable))
	}

	// JSONB columns are merged rather than overwritten, hence the raw SQL.
	if len(updates.Credentials) > 0 {
		encoded, err := json.Marshal(updates.Credentials)
		if err != nil {
			return 0, err
		}
		assignments = append(assignments, "credentials = COALESCE(credentials, '{}'::jsonb) || "+bind(encoded)+"::jsonb")
	}
	if len(updates.Extra) > 0 {
		encoded, err := json.Marshal(updates.Extra)
		if err != nil {
			return 0, err
		}
		assignments = append(assignments, "extra = COALESCE(extra, '{}'::jsonb) || "+bind(encoded)+"::jsonb")
	}

	if len(assignments) == 0 {
		return 0, nil
	}
	assignments = append(assignments, "updated_at = NOW()")

	stmt := "UPDATE accounts SET " + joinClauses(assignments, ", ") +
		" WHERE id = ANY(" + bind(pq.Array(ids)) + ") AND deleted_at IS NULL"

	res, err := r.sql.ExecContext(ctx, stmt, sqlArgs...)
	if err != nil {
		return 0, err
	}
	changed, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	if changed <= 0 {
		return changed, nil
	}

	payload := map[string]any{"account_ids": ids}
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
		log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
	}
	return changed, nil
}
// accountGroupQueryOptions carries the optional filters applied by
// queryAccountsByGroup.
type accountGroupQueryOptions struct {
status string // when non-empty, only accounts with exactly this status match
schedulable bool // when true, the schedulability filters (flag, temp-unschedulable, expiry, overload, rate limit) are added
platforms []string // allowed platforms; an empty slice means no platform filtering
}
// queryAccountsByGroup returns the accounts bound to groupID that satisfy
// opts, ordered by binding priority and then by account priority.
//
// Accounts are reached through the account_groups join table. Soft-deleted
// accounts are always excluded; status, platform and schedulability filters
// are added according to opts. If the same account appears in more than one
// binding row, only its first occurrence is kept.
func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) {
	// The deleted-at filter is unconditional, so filters is never empty.
	filters := make([]dbpredicate.Account, 0, 6)
	filters = append(filters, dbaccount.DeletedAtIsNil())
	if opts.status != "" {
		filters = append(filters, dbaccount.StatusEQ(opts.status))
	}
	if len(opts.platforms) > 0 {
		filters = append(filters, dbaccount.PlatformIn(opts.platforms...))
	}
	if opts.schedulable {
		now := time.Now()
		filters = append(filters,
			dbaccount.SchedulableEQ(true),
			tempUnschedulablePredicate(),
			notExpiredPredicate(now),
			dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
			dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
		)
	}

	bindings, err := r.client.AccountGroup.Query().
		Where(
			dbaccountgroup.GroupIDEQ(groupID),
			dbaccountgroup.HasAccountWith(filters...),
		).
		Order(
			dbaccountgroup.ByPriority(),
			dbaccountgroup.ByAccountField(dbaccount.FieldPriority),
		).
		WithAccount().
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Collect each account once, preserving the query order.
	seen := make(map[int64]struct{}, len(bindings))
	accounts := make([]*dbent.Account, 0, len(bindings))
	for _, binding := range bindings {
		acc := binding.Edges.Account
		if acc == nil {
			continue
		}
		if _, dup := seen[binding.AccountID]; dup {
			continue
		}
		seen[binding.AccountID] = struct{}{}
		accounts = append(accounts, acc)
	}
	return r.accountsToService(ctx, accounts)
}
// accountsToService converts Ent account entities into service-layer accounts
// and attaches their related data: proxy, groups, group IDs, account-group
// bindings, and temporary-unschedulable state.
//
// Each kind of related data is loaded once for the whole batch. An empty input
// yields an empty, non-nil slice. The first loader error aborts the conversion
// and is returned unchanged.
func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) {
	if len(accounts) == 0 {
		return []service.Account{}, nil
	}

	// Collect the keys needed by the batch loaders.
	ids := make([]int64, 0, len(accounts))
	proxyIDs := make([]int64, 0, len(accounts))
	for i := range accounts {
		entity := accounts[i]
		ids = append(ids, entity.ID)
		if pid := entity.ProxyID; pid != nil {
			proxyIDs = append(proxyIDs, *pid)
		}
	}

	proxies, err := r.loadProxies(ctx, proxyIDs)
	if err != nil {
		return nil, err
	}
	snapshots, err := r.loadTempUnschedStates(ctx, ids)
	if err != nil {
		return nil, err
	}
	groups, groupIDs, bindings, err := r.loadAccountGroups(ctx, ids)
	if err != nil {
		return nil, err
	}

	result := make([]service.Account, 0, len(accounts))
	for i := range accounts {
		entity := accounts[i]
		svc := accountEntityToService(entity)
		if svc == nil {
			continue
		}
		if pid := entity.ProxyID; pid != nil {
			if p, found := proxies[*pid]; found {
				svc.Proxy = p
			}
		}
		if g, found := groups[entity.ID]; found {
			svc.Groups = g
		}
		if gids, found := groupIDs[entity.ID]; found {
			svc.GroupIDs = gids
		}
		if b, found := bindings[entity.ID]; found {
			svc.AccountGroups = b
		}
		if state, found := snapshots[entity.ID]; found {
			svc.TempUnschedulableUntil = state.until
			svc.TempUnschedulableReason = state.reason
		}
		result = append(result, *svc)
	}
	return result, nil
}
// tempUnschedulablePredicate matches accounts whose temp_unschedulable_until
// column is NULL or not later than the database's NOW().
//
// The comparison uses the database clock (a raw NOW() expression), not a
// timestamp supplied by the application.
func tempUnschedulablePredicate() dbpredicate.Account {
	return dbpredicate.Account(func(s *entsql.Selector) {
		untilCol := s.C("temp_unschedulable_until")
		neverSet := entsql.IsNull(untilCol)
		alreadyPassed := entsql.LTE(untilCol, entsql.Expr("NOW()"))
		s.Where(entsql.Or(neverSet, alreadyPassed))
	})
}
// notExpiredPredicate matches accounts that should not be treated as expired
// at the given instant. An account matches when any of these holds:
//   - it has no expiry time,
//   - its expiry time is strictly after now,
//   - auto-pause on expiry is turned off for it.
func notExpiredPredicate(now time.Time) dbpredicate.Account {
	noExpiry := dbaccount.ExpiresAtIsNil()
	expiresLater := dbaccount.ExpiresAtGT(now)
	autoPauseOff := dbaccount.AutoPauseOnExpiredEQ(false)
	return dbaccount.Or(noExpiry, expiresLater, autoPauseOff)
}
// loadTempUnschedStates reads the temporary-unschedulable columns for the given
// accounts with a single raw SQL query and returns them keyed by account ID.
//
// A NULL temp_unschedulable_until becomes a nil pointer and a NULL
// temp_unschedulable_reason becomes the empty string. An empty accountIDs
// slice returns an empty map without touching the database.
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
	snapshots := make(map[int64]tempUnschedSnapshot)
	if len(accountIDs) == 0 {
		return snapshots, nil
	}

	const query = `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`
	rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs))
	if err != nil {
		return nil, err
	}
	// A Close error is deliberately ignored; iteration errors surface via rows.Err below.
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var (
			accountID int64
			until     sql.NullTime
			reason    sql.NullString
		)
		if scanErr := rows.Scan(&accountID, &until, &reason); scanErr != nil {
			return nil, scanErr
		}

		var snap tempUnschedSnapshot
		if until.Valid {
			// Copy into a fresh variable so each snapshot owns its own pointer target.
			deadline := until.Time
			snap.until = &deadline
		}
		if reason.Valid {
			snap.reason = reason.String
		}
		snapshots[accountID] = snap
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return snapshots, nil
}
// loadProxies fetches the proxies with the given IDs in one query and returns
// them as service-layer values keyed by proxy ID.
//
// An empty proxyIDs slice returns an empty map without querying. IDs with no
// matching row are simply absent from the result.
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
	byID := make(map[int64]*service.Proxy)
	if len(proxyIDs) == 0 {
		return byID, nil
	}

	query := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...))
	entities, err := query.All(ctx)
	if err != nil {
		return nil, err
	}
	for _, entity := range entities {
		byID[entity.ID] = proxyEntityToService(entity)
	}
	return byID, nil
}
// loadAccountGroups loads the group bindings for the given accounts and returns
// three maps keyed by account ID:
//
//  1. the converted groups (bindings whose group converts to nil are omitted),
//  2. the bound group IDs,
//  3. the full binding records, each carrying its converted group.
//
// Bindings are read ordered by account ID and then by binding priority, so each
// per-account slice follows priority order. An empty accountIDs slice returns
// three empty maps without querying.
func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) {
	var (
		groups   = make(map[int64][]*service.Group)
		groupIDs = make(map[int64][]int64)
		bindings = make(map[int64][]service.AccountGroup)
	)
	if len(accountIDs) == 0 {
		return groups, groupIDs, bindings, nil
	}

	links, err := r.client.AccountGroup.Query().
		Where(dbaccountgroup.AccountIDIn(accountIDs...)).
		WithGroup().
		Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()).
		All(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	for _, link := range links {
		accountID := link.AccountID
		group := groupEntityToService(link.Edges.Group)

		bindings[accountID] = append(bindings[accountID], service.AccountGroup{
			AccountID: accountID,
			GroupID:   link.GroupID,
			Priority:  link.Priority,
			CreatedAt: link.CreatedAt,
			Group:     group,
		})
		groupIDs[accountID] = append(groupIDs[accountID], link.GroupID)

		if group == nil {
			continue
		}
		groups[accountID] = append(groups[accountID], group)
	}
	return groups, groupIDs, bindings, nil
}
// loadAccountGroupIDs returns the IDs of every group the given account is bound
// to, in the order the binding rows are returned by the query. The result is
// non-nil even when the account has no bindings.
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
	query := r.client.AccountGroup.Query().Where(dbaccountgroup.AccountIDEQ(accountID))
	links, err := query.All(ctx)
	if err != nil {
		return nil, err
	}

	groupIDs := make([]int64, len(links))
	for i, link := range links {
		groupIDs[i] = link.GroupID
	}
	return groupIDs, nil
}
// mergeGroupIDs returns the union of a and b with duplicates removed.
//
// Order is preserved: IDs from a come first, followed by IDs from b that were
// not already seen. Non-positive IDs are dropped. The result is always
// non-nil, even when nothing survives the filtering.
func mergeGroupIDs(a []int64, b []int64) []int64 {
	total := len(a) + len(b)
	visited := make(map[int64]struct{}, total)
	merged := make([]int64, 0, total)

	for _, ids := range [2][]int64{a, b} {
		for _, id := range ids {
			if id <= 0 {
				continue
			}
			if _, dup := visited[id]; dup {
				continue
			}
			visited[id] = struct{}{}
			merged = append(merged, id)
		}
	}
	return merged
}
// buildSchedulerGroupPayload wraps groupIDs in a map under the "group_ids" key.
// It returns nil when groupIDs is empty.
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
	var payload map[string]any
	if len(groupIDs) > 0 {
		payload = map[string]any{"group_ids": groupIDs}
	}
	return payload
}
// accountEntityToService maps an Ent account entity onto a service.Account.
// It returns nil when m is nil.
//
// Credentials and Extra are shallow-copied, and RateMultiplier points at a
// local copy of the entity's value rather than at the entity's own field.
// Relations (proxy, groups, bindings) and temporary-unschedulable state are
// not populated here.
func accountEntityToService(m *dbent.Account) *service.Account {
	if m == nil {
		return nil
	}

	// Take a copy so the returned pointer does not alias m.RateMultiplier.
	multiplier := m.RateMultiplier

	account := service.Account{
		// Identity and configuration.
		ID:             m.ID,
		Name:           m.Name,
		Notes:          m.Notes,
		Platform:       m.Platform,
		Type:           m.Type,
		Credentials:    copyJSONMap(m.Credentials),
		Extra:          copyJSONMap(m.Extra),
		ProxyID:        m.ProxyID,
		Concurrency:    m.Concurrency,
		Priority:       m.Priority,
		RateMultiplier: &multiplier,

		// Status and lifecycle.
		Status:             m.Status,
		ErrorMessage:       derefString(m.ErrorMessage),
		LastUsedAt:         m.LastUsedAt,
		ExpiresAt:          m.ExpiresAt,
		AutoPauseOnExpired: m.AutoPauseOnExpired,
		CreatedAt:          m.CreatedAt,
		UpdatedAt:          m.UpdatedAt,

		// Scheduling state.
		Schedulable:         m.Schedulable,
		RateLimitedAt:       m.RateLimitedAt,
		RateLimitResetAt:    m.RateLimitResetAt,
		OverloadUntil:       m.OverloadUntil,
		SessionWindowStart:  m.SessionWindowStart,
		SessionWindowEnd:    m.SessionWindowEnd,
		SessionWindowStatus: derefString(m.SessionWindowStatus),
	}
	return &account
}
// normalizeJSONMap guarantees a non-nil map: a nil input is replaced by a new
// empty map, while a non-nil input is returned as-is (not copied).
func normalizeJSONMap(in map[string]any) map[string]any {
	if in != nil {
		return in
	}
	return make(map[string]any)
}
// copyJSONMap returns a shallow copy of in: the top-level keys are duplicated
// into a new map, but the values themselves are shared with the original.
// A nil input yields nil; an empty non-nil input yields an empty non-nil map.
func copyJSONMap(in map[string]any) map[string]any {
	var dup map[string]any
	if in != nil {
		dup = make(map[string]any, len(in))
		for key, val := range in {
			dup[key] = val
		}
	}
	return dup
}
// joinClauses concatenates clauses into one string, placing sep between
// consecutive elements. An empty or nil slice produces the empty string.
func joinClauses(clauses []string, sep string) string {
	var joined string
	for idx, clause := range clauses {
		if idx > 0 {
			joined += sep
		}
		joined += clause
	}
	return joined
}
// itoa formats v as a base-10 string.
func itoa(v int) string {
	return strconv.FormatInt(int64(v), 10)
}