Files
xinghuoapi/backend/internal/repository/account_repo.go
yangjianbo e9c755f428 perf(后端): 优化删除操作的数据库查询性能
- 新增 ExistsByID 方法用于账号存在性检查,避免加载完整对象
- 新增 GetOwnerID 方法用于 API Key 所有权验证,仅查询 user_id 字段
- 优化 AccountService.Delete 使用轻量级存在性检查
- 优化 ApiKeyService.Delete 使用轻量级权限验证
- 改进前端删除错误提示,显示后端返回的具体错误消息
- 添加详细的中文注释说明优化原因

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-29 14:06:38 +08:00

850 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package repository 实现数据访问层（Repository Pattern）。
//
// 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。
// 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。
//
// 主要特性:
// - 使用 Ent ORM 进行类型安全的数据库操作
// - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL
// - 提供统一的错误翻译机制,将数据库错误转换为业务错误
// - 支持软删除,所有查询自动过滤已删除记录
package repository
import (
	"context"
	"database/sql"
	"encoding/json"
	"strconv"
	"strings"
	"time"

	entsql "entgo.io/ent/dialect/sql"
	dbent "github.com/Wei-Shaw/sub2api/ent"
	dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
	dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup"
	dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
	dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
	dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/lib/pq"
)
// accountRepository is the service.AccountRepository implementation for AI API
// accounts. It keeps three handles:
//   - client: the Ent ORM client, used for type-safe queries and mutations;
//   - sql:    a raw SQL executor, used where hand-written SQL is needed
//     (batch updates, JSONB merges);
//   - begin:  a transaction starter, populated only when the SQL executor
//     also implements sqlBeginner (it may be nil).
type accountRepository struct {
	client *dbent.Client
	sql    sqlExecutor
	begin  sqlBeginner
}
// NewAccountRepository builds the account repository used by the service layer.
// The concrete type is hidden behind the service.AccountRepository interface so
// callers depend only on the interface, which keeps dependency injection simple.
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
	repo := newAccountRepositoryWithSQL(client, sqlDB)
	return repo
}
// newAccountRepositoryWithSQL is the internal constructor. It accepts any
// sqlExecutor so unit tests can inject a fake. When that executor can also
// begin transactions (it implements sqlBeginner), it is stored as the
// repository's beginner as well; otherwise begin is left nil.
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
	repo := &accountRepository{client: client, sql: sqlq}
	if txStarter, ok := sqlq.(sqlBeginner); ok {
		repo.begin = txStarter
	}
	return repo
}
// Create inserts account and, on success, copies the database-assigned ID,
// CreatedAt and UpdatedAt back into it. A nil account is a no-op.
//
// Credentials and Extra are normalized so nil maps are stored as empty
// objects. Optional pointer fields are only written when non-nil, and an empty
// SessionWindowStatus is left unset. Save errors are returned untranslated.
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
	if account == nil {
		return nil
	}

	create := r.client.Account.Create().
		SetName(account.Name).
		SetPlatform(account.Platform).
		SetType(account.Type).
		SetCredentials(normalizeJSONMap(account.Credentials)).
		SetExtra(normalizeJSONMap(account.Extra)).
		SetConcurrency(account.Concurrency).
		SetPriority(account.Priority).
		SetStatus(account.Status).
		SetErrorMessage(account.ErrorMessage).
		SetSchedulable(account.Schedulable).
		// SetNillable* only writes the column when the pointer is non-nil.
		SetNillableProxyID(account.ProxyID).
		SetNillableLastUsedAt(account.LastUsedAt).
		SetNillableRateLimitedAt(account.RateLimitedAt).
		SetNillableRateLimitResetAt(account.RateLimitResetAt).
		SetNillableOverloadUntil(account.OverloadUntil).
		SetNillableSessionWindowStart(account.SessionWindowStart).
		SetNillableSessionWindowEnd(account.SessionWindowEnd)
	if status := account.SessionWindowStatus; status != "" {
		create.SetSessionWindowStatus(status)
	}

	row, err := create.Save(ctx)
	if err != nil {
		return err
	}
	account.ID, account.CreatedAt, account.UpdatedAt = row.ID, row.CreatedAt, row.UpdatedAt
	return nil
}
// GetByID loads one account together with its proxy and group associations.
// Query errors are passed through translatePersistenceError with
// service.ErrAccountNotFound as the not-found error. If conversion yields
// nothing, service.ErrAccountNotFound is returned as well.
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	entity, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}

	converted, err := r.accountsToService(ctx, []*dbent.Account{entity})
	switch {
	case err != nil:
		return nil, err
	case len(converted) == 0:
		return nil, service.ErrAccountNotFound
	default:
		return &converted[0], nil
	}
}
// ExistsByID reports whether an account with the given ID exists.
//
// Unlike GetByID it uses Ent's Exist(), so neither the account row nor its
// associations (groups, proxy) are loaded. That makes it suitable for cheap
// existence checks, e.g. before a delete. On error it returns false.
func (r *accountRepository) ExistsByID(ctx context.Context, id int64) (bool, error) {
	found, err := r.client.Account.Query().
		Where(dbaccount.IDEQ(id)).
		Exist(ctx)
	if err == nil {
		return found, nil
	}
	return false, err
}
// GetByCRSAccountID returns the account whose extra JSON field
// "crs_account_id" equals crsAccountID.
//
// It returns (nil, nil) when crsAccountID is empty or no account matches. Any
// other error is returned unchanged, including Ent's not-singular error when
// more than one account carries the same ID (Only rejects multiple rows).
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
	if crsAccountID == "" {
		return nil, nil
	}

	// Raw predicate: Ent has no typed accessor for keys inside the JSONB column.
	var byCRSID dbpredicate.Account = func(s *entsql.Selector) {
		s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
	}
	entity, err := r.client.Account.Query().Where(byCRSID).Only(ctx)
	if dbent.IsNotFound(err) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	converted, err := r.accountsToService(ctx, []*dbent.Account{entity})
	if err != nil || len(converted) == 0 {
		return nil, err
	}
	return &converted[0], nil
}
// Update writes every mutable column of the account identified by account.ID.
//
// Unlike Create, optional values are mirrored exactly: a nil pointer (or an
// empty SessionWindowStatus) clears the stored column rather than leaving it
// untouched. On success account.UpdatedAt is refreshed from the saved row.
// Errors go through translatePersistenceError with service.ErrAccountNotFound
// as the not-found error. A nil account is a no-op.
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
	if account == nil {
		return nil
	}

	upd := r.client.Account.UpdateOneID(account.ID).
		SetName(account.Name).
		SetPlatform(account.Platform).
		SetType(account.Type).
		SetCredentials(normalizeJSONMap(account.Credentials)).
		SetExtra(normalizeJSONMap(account.Extra)).
		SetConcurrency(account.Concurrency).
		SetPriority(account.Priority).
		SetStatus(account.Status).
		SetErrorMessage(account.ErrorMessage).
		SetSchedulable(account.Schedulable)

	// Optional columns: clear when absent, set otherwise.
	if account.ProxyID == nil {
		upd.ClearProxyID()
	} else {
		upd.SetProxyID(*account.ProxyID)
	}
	if account.LastUsedAt == nil {
		upd.ClearLastUsedAt()
	} else {
		upd.SetLastUsedAt(*account.LastUsedAt)
	}
	if account.RateLimitedAt == nil {
		upd.ClearRateLimitedAt()
	} else {
		upd.SetRateLimitedAt(*account.RateLimitedAt)
	}
	if account.RateLimitResetAt == nil {
		upd.ClearRateLimitResetAt()
	} else {
		upd.SetRateLimitResetAt(*account.RateLimitResetAt)
	}
	if account.OverloadUntil == nil {
		upd.ClearOverloadUntil()
	} else {
		upd.SetOverloadUntil(*account.OverloadUntil)
	}
	if account.SessionWindowStart == nil {
		upd.ClearSessionWindowStart()
	} else {
		upd.SetSessionWindowStart(*account.SessionWindowStart)
	}
	if account.SessionWindowEnd == nil {
		upd.ClearSessionWindowEnd()
	} else {
		upd.SetSessionWindowEnd(*account.SessionWindowEnd)
	}
	if account.SessionWindowStatus == "" {
		upd.ClearSessionWindowStatus()
	} else {
		upd.SetSessionWindowStatus(account.SessionWindowStatus)
	}

	saved, err := upd.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}
	account.UpdatedAt = saved.UpdatedAt
	return nil
}
// Delete removes the account's group memberships and then the account itself.
//
// The two statements are issued separately and no transaction is opened here.
// If the second one fails, the memberships are already gone (unless r.client is
// itself bound to a transaction — NOTE(review): confirm how callers use this).
// Affected-row counts are discarded, so an ID that matches nothing is not an
// error.
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
	_, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx)
	if err != nil {
		return err
	}
	if _, err = r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
		return err
	}
	return nil
}
// List returns one page of accounts with no filtering. It is ListWithFilters
// with every filter argument left empty.
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
	const noFilter = ""
	return r.ListWithFilters(ctx, params, noFilter, noFilter, noFilter, noFilter)
}
// ListWithFilters returns one page of accounts, newest ID first, plus the
// pagination metadata computed from the total match count.
//
// Each non-empty argument adds a filter (combined with AND):
//   - platform, accountType, status: exact match;
//   - search: case-insensitive substring match on the account name.
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
	var filters []dbpredicate.Account
	if platform != "" {
		filters = append(filters, dbaccount.PlatformEQ(platform))
	}
	if accountType != "" {
		filters = append(filters, dbaccount.TypeEQ(accountType))
	}
	if status != "" {
		filters = append(filters, dbaccount.StatusEQ(status))
	}
	if search != "" {
		filters = append(filters, dbaccount.NameContainsFold(search))
	}
	query := r.client.Account.Query().Where(filters...)

	total, err := query.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	rows, err := query.
		Order(dbent.Desc(dbaccount.FieldID)).
		Offset(params.Offset()).
		Limit(params.Limit()).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	page, err := r.accountsToService(ctx, rows)
	if err != nil {
		return nil, nil, err
	}
	return page, paginationResultFromTotal(int64(total), params), nil
}
// ListByGroup returns the active accounts bound to groupID, in the order
// produced by queryAccountsByGroup. Schedulability is not checked.
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
	opts := accountGroupQueryOptions{status: service.StatusActive}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// ListActive returns every account whose status is service.StatusActive,
// sorted by ascending priority.
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
	rows, err := r.client.Account.Query().
		Where(dbaccount.StatusEQ(service.StatusActive)).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err == nil {
		return r.accountsToService(ctx, rows)
	}
	return nil, err
}
// ListByPlatform returns the active accounts of one platform, sorted by
// ascending priority.
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	query := r.client.Account.Query().
		Where(dbaccount.PlatformEQ(platform)).
		Where(dbaccount.StatusEQ(service.StatusActive)).
		Order(dbent.Asc(dbaccount.FieldPriority))

	rows, err := query.All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// UpdateLastUsed stamps last_used_at with the current time. The affected-row
// count is ignored, so an unknown ID is not an error.
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
	_, err := r.client.Account.Update().
		Where(dbaccount.IDEQ(id)).
		SetLastUsedAt(time.Now()).
		Save(ctx)
	return err
}
// BatchUpdateLastUsed sets last_used_at for many accounts in one statement.
//
// updates maps account ID to its new last-used timestamp; an empty map is a
// no-op. Soft-deleted rows (deleted_at IS NOT NULL) are skipped, and
// updated_at is set to NOW() on every touched row. The generated SQL is:
//
//	UPDATE accounts SET last_used_at = CASE id
//	  WHEN $1 THEN $2::timestamptz WHEN $3 THEN $4::timestamptz ...
//	END, updated_at = NOW() WHERE id = ANY($N) AND deleted_at IS NULL
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
	if len(updates) == 0 {
		return nil
	}

	ids := make([]int64, 0, len(updates))
	args := make([]any, 0, len(updates)*2+1)

	// strings.Builder avoids the quadratic cost of += on a growing string.
	var query strings.Builder
	query.WriteString("UPDATE accounts SET last_used_at = CASE id")
	idx := 1
	for id, ts := range updates {
		// The cast is required. When every THEN branch is an untyped parameter,
		// PostgreSQL resolves the CASE result as text. text has no implicit
		// (assignment) cast to a timestamp column, so the UPDATE would fail.
		// timestamptz assigns cleanly to both timestamp and timestamptz columns.
		query.WriteString(" WHEN $" + itoa(idx) + " THEN $" + itoa(idx+1) + "::timestamptz")
		args = append(args, id, ts)
		ids = append(ids, id)
		idx += 2
	}
	query.WriteString(" END, updated_at = NOW() WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL")
	args = append(args, pq.Array(ids))

	_, err := r.sql.ExecContext(ctx, query.String(), args...)
	return err
}
// SetError moves the account into service.StatusError and records errorMsg as
// its error message.
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
	update := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	update.SetStatus(service.StatusError)
	update.SetErrorMessage(errorMsg)
	_, err := update.Save(ctx)
	return err
}
// AddToGroup inserts one membership row linking accountID to groupID with the
// given priority. Existing memberships are not checked; any constraint error
// from the insert is returned as-is.
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
	membership := r.client.AccountGroup.Create().
		SetAccountID(accountID).
		SetGroupID(groupID).
		SetPriority(priority)
	if _, err := membership.Save(ctx); err != nil {
		return err
	}
	return nil
}
// RemoveFromGroup deletes the membership row(s) linking accountID to groupID.
// Removing a membership that does not exist is not an error.
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
	_, err := r.client.AccountGroup.Delete().
		Where(dbaccountgroup.AccountIDEQ(accountID)).
		Where(dbaccountgroup.GroupIDEQ(groupID)).
		Exec(ctx)
	return err
}
// GetGroups returns, by value, every group the account belongs to.
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
	rows, err := r.client.Group.Query().
		Where(dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID))).
		All(ctx)
	if err != nil {
		return nil, err
	}

	result := make([]service.Group, len(rows))
	for i, g := range rows {
		result[i] = *groupEntityToService(g)
	}
	return result, nil
}
// BindGroups replaces the account's group memberships with groupIDs.
//
// All existing memberships are deleted first. The new ones are then
// bulk-inserted with priorities 1..len(groupIDs), following slice order. An
// empty groupIDs simply unbinds the account from every group. The delete and
// the insert are separate statements; no transaction is opened here.
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
	_, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx)
	if err != nil || len(groupIDs) == 0 {
		return err
	}

	creates := make([]*dbent.AccountGroupCreate, len(groupIDs))
	for i := range groupIDs {
		creates[i] = r.client.AccountGroup.Create().
			SetAccountID(accountID).
			SetGroupID(groupIDs[i]).
			SetPriority(i + 1)
	}
	_, err = r.client.AccountGroup.CreateBulk(creates...).Save(ctx)
	return err
}
// ListSchedulable returns every account that can receive traffic right now,
// sorted by ascending priority. "Right now" means all of the following hold:
//   - status is active and the schedulable flag is set;
//   - overload_until is unset or already passed;
//   - rate_limit_reset_at is unset or already passed.
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
	now := time.Now()
	notOverloaded := dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now))
	notRateLimited := dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now))

	rows, err := r.client.Account.Query().
		Where(
			dbaccount.StatusEQ(service.StatusActive),
			dbaccount.SchedulableEQ(true),
			notOverloaded,
			notRateLimited,
		).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListSchedulableByGroupID returns the accounts of groupID that can take
// traffic now: active, schedulable, not overloaded and not rate limited.
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
	opts := accountGroupQueryOptions{
		schedulable: true,
		status:      service.StatusActive,
	}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// ListSchedulableByPlatform is ListSchedulable restricted to one platform. It
// returns active, schedulable accounts that are neither overloaded nor rate
// limited at the current time, sorted by ascending priority.
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	now := time.Now()
	notOverloaded := dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now))
	notRateLimited := dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now))

	rows, err := r.client.Account.Query().
		Where(
			dbaccount.PlatformEQ(platform),
			dbaccount.StatusEQ(service.StatusActive),
			dbaccount.SchedulableEQ(true),
			notOverloaded,
			notRateLimited,
		).
		Order(dbent.Asc(dbaccount.FieldPriority)).
		All(ctx)
	if err != nil {
		return nil, err
	}
	return r.accountsToService(ctx, rows)
}
// ListSchedulableByGroupIDAndPlatform combines the group and platform
// filters. It returns the accounts of groupID on the given platform that are
// active, schedulable, and neither overloaded nor rate limited now.
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	opts := accountGroupQueryOptions{
		platform:    platform,
		schedulable: true,
		status:      service.StatusActive,
	}
	return r.queryAccountsByGroup(ctx, groupID, opts)
}
// SetRateLimited marks the account as rate limited. It stamps rate_limited_at
// with the current time and records when the limit lifts (resetAt).
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	_, err := r.client.Account.Update().
		Where(dbaccount.IDEQ(id)).
		SetRateLimitedAt(time.Now()).
		SetRateLimitResetAt(resetAt).
		Save(ctx)
	return err
}
// SetOverloaded records that the account is overloaded until the given time by
// writing overload_until. Other throttling columns are not touched.
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	update := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	if _, err := update.SetOverloadUntil(until).Save(ctx); err != nil {
		return err
	}
	return nil
}
// ClearRateLimit removes every throttling marker from the account. It clears
// rate_limited_at, rate_limit_reset_at and overload_until together.
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
	update := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	update.ClearRateLimitedAt().ClearRateLimitResetAt().ClearOverloadUntil()
	_, err := update.Save(ctx)
	return err
}
// UpdateSessionWindow always writes the session-window status. start and end
// are written only when non-nil; a nil pointer leaves the stored value
// unchanged rather than clearing it.
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	update := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	update.SetSessionWindowStatus(status)
	if start != nil {
		update.SetSessionWindowStart(*start)
	}
	if end != nil {
		update.SetSessionWindowEnd(*end)
	}
	if _, err := update.Save(ctx); err != nil {
		return err
	}
	return nil
}
// SetSchedulable toggles whether the scheduler may pick this account.
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
	update := r.client.Account.Update().Where(dbaccount.IDEQ(id))
	update.SetSchedulable(schedulable)
	_, err := update.Save(ctx)
	return err
}
// UpdateExtra merges updates into the account's extra JSON object. Top-level
// keys in updates overwrite existing ones; all other keys are kept. An empty
// updates map is a no-op.
//
// The merge runs as a single UPDATE using the JSONB || operator, the same
// pattern BulkUpdate uses. The previous read-then-write approach lost updates
// when two callers modified different keys concurrently: each wrote back its
// own stale copy of the whole map. Soft-deleted accounts are not touched. If no
// live account has this ID, service.ErrAccountNotFound is returned.
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
	if len(updates) == 0 {
		return nil
	}

	payload, err := json.Marshal(updates)
	if err != nil {
		return err
	}

	result, err := r.sql.ExecContext(ctx,
		"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
		payload, id,
	)
	if err != nil {
		return err
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		return service.ErrAccountNotFound
	}
	return nil
}
// BulkUpdate applies the non-nil fields of updates to every live account in
// ids with one UPDATE statement. It returns the number of affected rows.
//
// Scalar fields overwrite the column. Credentials and Extra are JSONB-merged
// (existing keys kept, given keys overwritten) rather than replaced, which is
// why raw SQL is used instead of Ent. It returns (0, nil) when ids is empty or
// updates sets nothing. Soft-deleted rows are skipped and updated_at is bumped.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}

	var (
		assignments []string
		params      []any
	)
	// bind appends v to the argument list and returns its "$n" placeholder.
	bind := func(v any) string {
		params = append(params, v)
		return "$" + itoa(len(params))
	}

	if updates.Name != nil {
		assignments = append(assignments, "name = "+bind(*updates.Name))
	}
	if updates.ProxyID != nil {
		assignments = append(assignments, "proxy_id = "+bind(*updates.ProxyID))
	}
	if updates.Concurrency != nil {
		assignments = append(assignments, "concurrency = "+bind(*updates.Concurrency))
	}
	if updates.Priority != nil {
		assignments = append(assignments, "priority = "+bind(*updates.Priority))
	}
	if updates.Status != nil {
		assignments = append(assignments, "status = "+bind(*updates.Status))
	}
	if len(updates.Credentials) > 0 {
		payload, err := json.Marshal(updates.Credentials)
		if err != nil {
			return 0, err
		}
		assignments = append(assignments, "credentials = COALESCE(credentials, '{}'::jsonb) || "+bind(payload)+"::jsonb")
	}
	if len(updates.Extra) > 0 {
		payload, err := json.Marshal(updates.Extra)
		if err != nil {
			return 0, err
		}
		assignments = append(assignments, "extra = COALESCE(extra, '{}'::jsonb) || "+bind(payload)+"::jsonb")
	}
	if len(assignments) == 0 {
		return 0, nil
	}

	assignments = append(assignments, "updated_at = NOW()")
	setList := joinClauses(assignments, ", ")
	idsParam := bind(pq.Array(ids))
	query := "UPDATE accounts SET " + setList + " WHERE id = ANY(" + idsParam + ") AND deleted_at IS NULL"

	res, err := r.sql.ExecContext(ctx, query, params...)
	if err != nil {
		return 0, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	return affected, nil
}
// accountGroupQueryOptions narrows the accounts returned by
// queryAccountsByGroup. A zero value disables the corresponding filter.
type accountGroupQueryOptions struct {
	status      string // non-empty: only accounts with exactly this status
	schedulable bool   // true: only schedulable accounts that are not overloaded or rate limited now
	platform    string // non-empty: only accounts on exactly this platform
}
// queryAccountsByGroup returns the accounts bound to groupID that pass opts.
//
// Results are ordered by membership priority, then by account priority.
// Soft-deleted accounts are always excluded through an explicit DeletedAtIsNil
// predicate on the account edge. If an account shows up in several membership
// rows, only its first occurrence (in that order) is kept.
func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) {
	accountPreds := []dbpredicate.Account{dbaccount.DeletedAtIsNil()}
	if opts.status != "" {
		accountPreds = append(accountPreds, dbaccount.StatusEQ(opts.status))
	}
	if opts.platform != "" {
		accountPreds = append(accountPreds, dbaccount.PlatformEQ(opts.platform))
	}
	if opts.schedulable {
		now := time.Now()
		accountPreds = append(accountPreds,
			dbaccount.SchedulableEQ(true),
			dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
			dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
		)
	}

	memberships, err := r.client.AccountGroup.Query().
		Where(
			dbaccountgroup.GroupIDEQ(groupID),
			dbaccountgroup.HasAccountWith(accountPreds...),
		).
		Order(
			dbaccountgroup.ByPriority(),
			dbaccountgroup.ByAccountField(dbaccount.FieldPriority),
		).
		WithAccount().
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Single pass: keep the first occurrence of each account, preserving order.
	seen := make(map[int64]struct{}, len(memberships))
	ordered := make([]*dbent.Account, 0, len(memberships))
	for _, m := range memberships {
		acc := m.Edges.Account
		if acc == nil {
			continue
		}
		if _, dup := seen[m.AccountID]; dup {
			continue
		}
		seen[m.AccountID] = struct{}{}
		ordered = append(ordered, acc)
	}
	return r.accountsToService(ctx, ordered)
}
// accountsToService converts Ent account rows into service models and attaches
// their proxy and group associations.
//
// Proxies and memberships are each loaded with one batched query for all input
// rows, not one query per account. Input order is preserved. Rows that convert
// to nil are skipped. An empty input yields an empty (non-nil) slice.
func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) {
	if len(accounts) == 0 {
		return []service.Account{}, nil
	}

	accountIDs := make([]int64, len(accounts))
	proxyIDs := make([]int64, 0, len(accounts))
	for i, acc := range accounts {
		accountIDs[i] = acc.ID
		if acc.ProxyID != nil {
			proxyIDs = append(proxyIDs, *acc.ProxyID)
		}
	}

	proxyMap, err := r.loadProxies(ctx, proxyIDs)
	if err != nil {
		return nil, err
	}
	groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
	if err != nil {
		return nil, err
	}

	result := make([]service.Account, 0, len(accounts))
	for _, acc := range accounts {
		svc := accountEntityToService(acc)
		if svc == nil {
			continue
		}
		// accountEntityToService leaves these fields at their zero value, so a
		// missing map entry (which reads as nil) leaves them unchanged.
		if acc.ProxyID != nil {
			svc.Proxy = proxyMap[*acc.ProxyID]
		}
		svc.Groups = groupsByAccount[acc.ID]
		svc.GroupIDs = groupIDsByAccount[acc.ID]
		svc.AccountGroups = accountGroupsByAccount[acc.ID]
		result = append(result, *svc)
	}
	return result, nil
}
// loadProxies fetches the given proxies in one query and indexes them by ID.
// An empty proxyIDs returns an empty map without touching the database. IDs
// that match no proxy are simply absent from the result.
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
	if len(proxyIDs) == 0 {
		return make(map[int64]*service.Proxy), nil
	}

	rows, err := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...)).All(ctx)
	if err != nil {
		return nil, err
	}

	byID := make(map[int64]*service.Proxy, len(rows))
	for _, row := range rows {
		byID[row.ID] = proxyEntityToService(row)
	}
	return byID, nil
}
// loadAccountGroups loads, in one query, the group memberships of every account
// in accountIDs. It returns three maps keyed by account ID:
//   - the converted groups (memberships whose group converts to nil are omitted);
//   - the group IDs of every membership;
//   - the membership records themselves, each carrying its converted group.
//
// Entries are ordered by account ID, then membership priority. An empty
// accountIDs returns three empty maps without querying.
func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) {
	groups := make(map[int64][]*service.Group)
	groupIDs := make(map[int64][]int64)
	memberships := make(map[int64][]service.AccountGroup)
	if len(accountIDs) == 0 {
		return groups, groupIDs, memberships, nil
	}

	rows, err := r.client.AccountGroup.Query().
		Where(dbaccountgroup.AccountIDIn(accountIDs...)).
		WithGroup().
		Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()).
		All(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	for _, row := range rows {
		owner := row.AccountID
		group := groupEntityToService(row.Edges.Group)

		memberships[owner] = append(memberships[owner], service.AccountGroup{
			AccountID: row.AccountID,
			GroupID:   row.GroupID,
			Priority:  row.Priority,
			CreatedAt: row.CreatedAt,
			Group:     group,
		})
		groupIDs[owner] = append(groupIDs[owner], row.GroupID)
		if group != nil {
			groups[owner] = append(groups[owner], group)
		}
	}
	return groups, groupIDs, memberships, nil
}
// accountEntityToService maps an Ent account row onto the service-layer model.
//
// A nil row yields nil. Credentials and Extra are shallow-copied via
// copyJSONMap. ErrorMessage and SessionWindowStatus are dereferenced through
// derefString. Proxy and group associations are not filled in here.
func accountEntityToService(m *dbent.Account) *service.Account {
	if m == nil {
		return nil
	}

	acc := &service.Account{
		// Identity and configuration.
		ID:          m.ID,
		Name:        m.Name,
		Platform:    m.Platform,
		Type:        m.Type,
		Credentials: copyJSONMap(m.Credentials),
		Extra:       copyJSONMap(m.Extra),
		ProxyID:     m.ProxyID,
		Concurrency: m.Concurrency,
		Priority:    m.Priority,
		CreatedAt:   m.CreatedAt,
		UpdatedAt:   m.UpdatedAt,
	}

	// Health and scheduling state.
	acc.Status = m.Status
	acc.ErrorMessage = derefString(m.ErrorMessage)
	acc.Schedulable = m.Schedulable
	acc.LastUsedAt = m.LastUsedAt
	acc.RateLimitedAt = m.RateLimitedAt
	acc.RateLimitResetAt = m.RateLimitResetAt
	acc.OverloadUntil = m.OverloadUntil
	acc.SessionWindowStart = m.SessionWindowStart
	acc.SessionWindowEnd = m.SessionWindowEnd
	acc.SessionWindowStatus = derefString(m.SessionWindowStatus)
	return acc
}
// normalizeJSONMap guarantees a non-nil map for JSON columns. A nil input
// becomes a fresh empty map. Any other input, including an empty non-nil map,
// is returned as-is: the same map, not a copy.
func normalizeJSONMap(in map[string]any) map[string]any {
	if in != nil {
		return in
	}
	return make(map[string]any)
}
// copyJSONMap returns a shallow copy of in. Top-level entries are duplicated,
// but nested maps and slices are still shared with the source. A nil input
// stays nil; an empty non-nil input yields a new empty non-nil map.
func copyJSONMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}
	clone := make(map[string]any, len(in))
	for key := range in {
		clone[key] = in[key]
	}
	return clone
}
// joinClauses concatenates clauses, placing sep between consecutive elements.
// It returns "" for an empty or nil slice.
//
// It delegates to strings.Join, which sizes the result once. The previous
// hand-rolled += loop re-copied the growing string on every iteration.
func joinClauses(clauses []string, sep string) string {
	return strings.Join(clauses, sep)
}
// itoa formats v in base 10. It is a short alias used when building SQL
// positional placeholders such as "$1".
func itoa(v int) string {
	return strconv.FormatInt(int64(v), 10)
}