- 新增 ExistsByID 方法用于账号存在性检查,避免加载完整对象 - 新增 GetOwnerID 方法用于 API Key 所有权验证,仅查询 user_id 字段 - 优化 AccountService.Delete 使用轻量级存在性检查 - 优化 ApiKeyService.Delete 使用轻量级权限验证 - 改进前端删除错误提示,显示后端返回的具体错误消息 - 添加详细的中文注释说明优化原因 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
850 lines
24 KiB
Go
850 lines
24 KiB
Go
// Package repository 实现数据访问层(Repository Pattern)。
|
||
//
|
||
// 该包提供了与数据库交互的所有操作,包括 CRUD、复杂查询和批量操作。
|
||
// 采用 Repository 模式将数据访问逻辑与业务逻辑分离,便于测试和维护。
|
||
//
|
||
// 主要特性:
|
||
// - 使用 Ent ORM 进行类型安全的数据库操作
|
||
// - 对于复杂查询(如批量更新、聚合统计)使用原生 SQL
|
||
// - 提供统一的错误翻译机制,将数据库错误转换为业务错误
|
||
// - 支持软删除,所有查询自动过滤已删除记录
|
||
package repository
|
||
|
||
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"strconv"
	"strings"
	"time"

	dbent "github.com/Wei-Shaw/sub2api/ent"
	dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
	dbaccountgroup "github.com/Wei-Shaw/sub2api/ent/accountgroup"
	dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
	dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
	dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/lib/pq"

	entsql "entgo.io/ent/dialect/sql"
)
|
||
|
||
// accountRepository 实现 service.AccountRepository 接口。
|
||
// 提供 AI API 账户的完整数据访问功能。
|
||
//
|
||
// 设计说明:
|
||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||
// - begin: SQL 事务开启器,用于需要事务的操作
|
||
type accountRepository struct {
|
||
client *dbent.Client // Ent ORM 客户端
|
||
sql sqlExecutor // 原生 SQL 执行接口
|
||
begin sqlBeginner // 事务开启接口
|
||
}
|
||
|
||
// NewAccountRepository 创建账户仓储实例。
|
||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
||
return newAccountRepositoryWithSQL(client, sqlDB)
|
||
}
|
||
|
||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||
// 这种设计便于单元测试时注入 mock 对象。
|
||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
||
var beginner sqlBeginner
|
||
if b, ok := sqlq.(sqlBeginner); ok {
|
||
beginner = b
|
||
}
|
||
return &accountRepository{client: client, sql: sqlq, begin: beginner}
|
||
}
|
||
|
||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||
if account == nil {
|
||
return nil
|
||
}
|
||
|
||
builder := r.client.Account.Create().
|
||
SetName(account.Name).
|
||
SetPlatform(account.Platform).
|
||
SetType(account.Type).
|
||
SetCredentials(normalizeJSONMap(account.Credentials)).
|
||
SetExtra(normalizeJSONMap(account.Extra)).
|
||
SetConcurrency(account.Concurrency).
|
||
SetPriority(account.Priority).
|
||
SetStatus(account.Status).
|
||
SetErrorMessage(account.ErrorMessage).
|
||
SetSchedulable(account.Schedulable)
|
||
|
||
if account.ProxyID != nil {
|
||
builder.SetProxyID(*account.ProxyID)
|
||
}
|
||
if account.LastUsedAt != nil {
|
||
builder.SetLastUsedAt(*account.LastUsedAt)
|
||
}
|
||
if account.RateLimitedAt != nil {
|
||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||
}
|
||
if account.RateLimitResetAt != nil {
|
||
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
|
||
}
|
||
if account.OverloadUntil != nil {
|
||
builder.SetOverloadUntil(*account.OverloadUntil)
|
||
}
|
||
if account.SessionWindowStart != nil {
|
||
builder.SetSessionWindowStart(*account.SessionWindowStart)
|
||
}
|
||
if account.SessionWindowEnd != nil {
|
||
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
|
||
}
|
||
if account.SessionWindowStatus != "" {
|
||
builder.SetSessionWindowStatus(account.SessionWindowStatus)
|
||
}
|
||
|
||
created, err := builder.Save(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
account.ID = created.ID
|
||
account.CreatedAt = created.CreatedAt
|
||
account.UpdatedAt = created.UpdatedAt
|
||
return nil
|
||
}
|
||
|
||
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||
m, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Only(ctx)
|
||
if err != nil {
|
||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||
}
|
||
|
||
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(accounts) == 0 {
|
||
return nil, service.ErrAccountNotFound
|
||
}
|
||
return &accounts[0], nil
|
||
}
|
||
|
||
// ExistsByID 检查指定 ID 的账号是否存在。
|
||
// 相比 GetByID,此方法性能更优,因为:
|
||
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
|
||
// - 不加载完整的账号实体及其关联数据(Groups、Proxy 等)
|
||
// - 适用于删除前的存在性检查等只需判断有无的场景
|
||
func (r *accountRepository) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||
exists, err := r.client.Account.Query().Where(dbaccount.IDEQ(id)).Exist(ctx)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return exists, nil
|
||
}
|
||
|
||
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||
if crsAccountID == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
m, err := r.client.Account.Query().
|
||
Where(func(s *entsql.Selector) {
|
||
s.Where(entsql.ExprP("extra->>'crs_account_id' = ?", crsAccountID))
|
||
}).
|
||
Only(ctx)
|
||
if err != nil {
|
||
if dbent.IsNotFound(err) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
accounts, err := r.accountsToService(ctx, []*dbent.Account{m})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(accounts) == 0 {
|
||
return nil, nil
|
||
}
|
||
return &accounts[0], nil
|
||
}
|
||
|
||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||
if account == nil {
|
||
return nil
|
||
}
|
||
|
||
builder := r.client.Account.UpdateOneID(account.ID).
|
||
SetName(account.Name).
|
||
SetPlatform(account.Platform).
|
||
SetType(account.Type).
|
||
SetCredentials(normalizeJSONMap(account.Credentials)).
|
||
SetExtra(normalizeJSONMap(account.Extra)).
|
||
SetConcurrency(account.Concurrency).
|
||
SetPriority(account.Priority).
|
||
SetStatus(account.Status).
|
||
SetErrorMessage(account.ErrorMessage).
|
||
SetSchedulable(account.Schedulable)
|
||
|
||
if account.ProxyID != nil {
|
||
builder.SetProxyID(*account.ProxyID)
|
||
} else {
|
||
builder.ClearProxyID()
|
||
}
|
||
if account.LastUsedAt != nil {
|
||
builder.SetLastUsedAt(*account.LastUsedAt)
|
||
} else {
|
||
builder.ClearLastUsedAt()
|
||
}
|
||
if account.RateLimitedAt != nil {
|
||
builder.SetRateLimitedAt(*account.RateLimitedAt)
|
||
} else {
|
||
builder.ClearRateLimitedAt()
|
||
}
|
||
if account.RateLimitResetAt != nil {
|
||
builder.SetRateLimitResetAt(*account.RateLimitResetAt)
|
||
} else {
|
||
builder.ClearRateLimitResetAt()
|
||
}
|
||
if account.OverloadUntil != nil {
|
||
builder.SetOverloadUntil(*account.OverloadUntil)
|
||
} else {
|
||
builder.ClearOverloadUntil()
|
||
}
|
||
if account.SessionWindowStart != nil {
|
||
builder.SetSessionWindowStart(*account.SessionWindowStart)
|
||
} else {
|
||
builder.ClearSessionWindowStart()
|
||
}
|
||
if account.SessionWindowEnd != nil {
|
||
builder.SetSessionWindowEnd(*account.SessionWindowEnd)
|
||
} else {
|
||
builder.ClearSessionWindowEnd()
|
||
}
|
||
if account.SessionWindowStatus != "" {
|
||
builder.SetSessionWindowStatus(account.SessionWindowStatus)
|
||
} else {
|
||
builder.ClearSessionWindowStatus()
|
||
}
|
||
|
||
updated, err := builder.Save(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||
}
|
||
account.UpdatedAt = updated.UpdatedAt
|
||
return nil
|
||
}
|
||
|
||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
|
||
return err
|
||
}
|
||
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||
}
|
||
|
||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||
q := r.client.Account.Query()
|
||
|
||
if platform != "" {
|
||
q = q.Where(dbaccount.PlatformEQ(platform))
|
||
}
|
||
if accountType != "" {
|
||
q = q.Where(dbaccount.TypeEQ(accountType))
|
||
}
|
||
if status != "" {
|
||
q = q.Where(dbaccount.StatusEQ(status))
|
||
}
|
||
if search != "" {
|
||
q = q.Where(dbaccount.NameContainsFold(search))
|
||
}
|
||
|
||
total, err := q.Count(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
accounts, err := q.
|
||
Offset(params.Offset()).
|
||
Limit(params.Limit()).
|
||
Order(dbent.Desc(dbaccount.FieldID)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
outAccounts, err := r.accountsToService(ctx, accounts)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return outAccounts, paginationResultFromTotal(int64(total), params), nil
|
||
}
|
||
|
||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
|
||
status: service.StatusActive,
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return accounts, nil
|
||
}
|
||
|
||
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
|
||
accounts, err := r.client.Account.Query().
|
||
Where(dbaccount.StatusEQ(service.StatusActive)).
|
||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return r.accountsToService(ctx, accounts)
|
||
}
|
||
|
||
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||
accounts, err := r.client.Account.Query().
|
||
Where(
|
||
dbaccount.PlatformEQ(platform),
|
||
dbaccount.StatusEQ(service.StatusActive),
|
||
).
|
||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return r.accountsToService(ctx, accounts)
|
||
}
|
||
|
||
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||
now := time.Now()
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetLastUsedAt(now).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||
if len(updates) == 0 {
|
||
return nil
|
||
}
|
||
|
||
ids := make([]int64, 0, len(updates))
|
||
args := make([]any, 0, len(updates)*2+1)
|
||
caseSQL := "UPDATE accounts SET last_used_at = CASE id"
|
||
|
||
idx := 1
|
||
for id, ts := range updates {
|
||
caseSQL += " WHEN $" + itoa(idx) + " THEN $" + itoa(idx+1)
|
||
args = append(args, id, ts)
|
||
ids = append(ids, id)
|
||
idx += 2
|
||
}
|
||
|
||
caseSQL += " END, updated_at = NOW() WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
|
||
args = append(args, pq.Array(ids))
|
||
|
||
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetStatus(service.StatusError).
|
||
SetErrorMessage(errorMsg).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||
_, err := r.client.AccountGroup.Create().
|
||
SetAccountID(accountID).
|
||
SetGroupID(groupID).
|
||
SetPriority(priority).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||
_, err := r.client.AccountGroup.Delete().
|
||
Where(
|
||
dbaccountgroup.AccountIDEQ(accountID),
|
||
dbaccountgroup.GroupIDEQ(groupID),
|
||
).
|
||
Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
||
groups, err := r.client.Group.Query().
|
||
Where(
|
||
dbgroup.HasAccountsWith(dbaccount.IDEQ(accountID)),
|
||
).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
outGroups := make([]service.Group, 0, len(groups))
|
||
for i := range groups {
|
||
outGroups = append(outGroups, *groupEntityToService(groups[i]))
|
||
}
|
||
return outGroups, nil
|
||
}
|
||
|
||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(groupIDs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
|
||
for i, groupID := range groupIDs {
|
||
builders = append(builders, r.client.AccountGroup.Create().
|
||
SetAccountID(accountID).
|
||
SetGroupID(groupID).
|
||
SetPriority(i+1),
|
||
)
|
||
}
|
||
|
||
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||
now := time.Now()
|
||
accounts, err := r.client.Account.Query().
|
||
Where(
|
||
dbaccount.StatusEQ(service.StatusActive),
|
||
dbaccount.SchedulableEQ(true),
|
||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||
).
|
||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return r.accountsToService(ctx, accounts)
|
||
}
|
||
|
||
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||
return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
|
||
status: service.StatusActive,
|
||
schedulable: true,
|
||
})
|
||
}
|
||
|
||
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||
now := time.Now()
|
||
accounts, err := r.client.Account.Query().
|
||
Where(
|
||
dbaccount.PlatformEQ(platform),
|
||
dbaccount.StatusEQ(service.StatusActive),
|
||
dbaccount.SchedulableEQ(true),
|
||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||
).
|
||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return r.accountsToService(ctx, accounts)
|
||
}
|
||
|
||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||
return r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
|
||
status: service.StatusActive,
|
||
schedulable: true,
|
||
platform: platform,
|
||
})
|
||
}
|
||
|
||
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||
now := time.Now()
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetRateLimitedAt(now).
|
||
SetRateLimitResetAt(resetAt).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetOverloadUntil(until).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
ClearRateLimitedAt().
|
||
ClearRateLimitResetAt().
|
||
ClearOverloadUntil().
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||
builder := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetSessionWindowStatus(status)
|
||
if start != nil {
|
||
builder.SetSessionWindowStart(*start)
|
||
}
|
||
if end != nil {
|
||
builder.SetSessionWindowEnd(*end)
|
||
}
|
||
_, err := builder.Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||
_, err := r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetSchedulable(schedulable).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||
if len(updates) == 0 {
|
||
return nil
|
||
}
|
||
|
||
accountExtra, err := r.client.Account.Query().
|
||
Where(dbaccount.IDEQ(id)).
|
||
Select(dbaccount.FieldExtra).
|
||
Only(ctx)
|
||
if err != nil {
|
||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||
}
|
||
|
||
extra := normalizeJSONMap(accountExtra.Extra)
|
||
for k, v := range updates {
|
||
extra[k] = v
|
||
}
|
||
|
||
_, err = r.client.Account.Update().
|
||
Where(dbaccount.IDEQ(id)).
|
||
SetExtra(extra).
|
||
Save(ctx)
|
||
return err
|
||
}
|
||
|
||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||
if len(ids) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
setClauses := make([]string, 0, 8)
|
||
args := make([]any, 0, 8)
|
||
|
||
idx := 1
|
||
if updates.Name != nil {
|
||
setClauses = append(setClauses, "name = $"+itoa(idx))
|
||
args = append(args, *updates.Name)
|
||
idx++
|
||
}
|
||
if updates.ProxyID != nil {
|
||
setClauses = append(setClauses, "proxy_id = $"+itoa(idx))
|
||
args = append(args, *updates.ProxyID)
|
||
idx++
|
||
}
|
||
if updates.Concurrency != nil {
|
||
setClauses = append(setClauses, "concurrency = $"+itoa(idx))
|
||
args = append(args, *updates.Concurrency)
|
||
idx++
|
||
}
|
||
if updates.Priority != nil {
|
||
setClauses = append(setClauses, "priority = $"+itoa(idx))
|
||
args = append(args, *updates.Priority)
|
||
idx++
|
||
}
|
||
if updates.Status != nil {
|
||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||
args = append(args, *updates.Status)
|
||
idx++
|
||
}
|
||
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
|
||
if len(updates.Credentials) > 0 {
|
||
payload, err := json.Marshal(updates.Credentials)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
setClauses = append(setClauses, "credentials = COALESCE(credentials, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
|
||
args = append(args, payload)
|
||
idx++
|
||
}
|
||
if len(updates.Extra) > 0 {
|
||
payload, err := json.Marshal(updates.Extra)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
setClauses = append(setClauses, "extra = COALESCE(extra, '{}'::jsonb) || $"+itoa(idx)+"::jsonb")
|
||
args = append(args, payload)
|
||
idx++
|
||
}
|
||
|
||
if len(setClauses) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
setClauses = append(setClauses, "updated_at = NOW()")
|
||
|
||
query := "UPDATE accounts SET " + joinClauses(setClauses, ", ") + " WHERE id = ANY($" + itoa(idx) + ") AND deleted_at IS NULL"
|
||
args = append(args, pq.Array(ids))
|
||
|
||
result, err := r.sql.ExecContext(ctx, query, args...)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
rows, err := result.RowsAffected()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return rows, nil
|
||
}
|
||
|
||
// accountGroupQueryOptions narrows queryAccountsByGroup. A zero value for a
// field disables the corresponding filter.
type accountGroupQueryOptions struct {
	status      string // when non-empty, only accounts with this status
	schedulable bool   // when true, only schedulable accounts not currently overloaded or rate limited
	platform    string // when non-empty, only accounts on this platform
}
|
||
|
||
func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID int64, opts accountGroupQueryOptions) ([]service.Account, error) {
|
||
q := r.client.AccountGroup.Query().
|
||
Where(dbaccountgroup.GroupIDEQ(groupID))
|
||
|
||
preds := make([]dbpredicate.Account, 0, 6)
|
||
preds = append(preds, dbaccount.DeletedAtIsNil())
|
||
if opts.status != "" {
|
||
preds = append(preds, dbaccount.StatusEQ(opts.status))
|
||
}
|
||
if opts.platform != "" {
|
||
preds = append(preds, dbaccount.PlatformEQ(opts.platform))
|
||
}
|
||
if opts.schedulable {
|
||
now := time.Now()
|
||
preds = append(preds,
|
||
dbaccount.SchedulableEQ(true),
|
||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||
)
|
||
}
|
||
|
||
if len(preds) > 0 {
|
||
q = q.Where(dbaccountgroup.HasAccountWith(preds...))
|
||
}
|
||
|
||
groups, err := q.
|
||
Order(
|
||
dbaccountgroup.ByPriority(),
|
||
dbaccountgroup.ByAccountField(dbaccount.FieldPriority),
|
||
).
|
||
WithAccount().
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
orderedIDs := make([]int64, 0, len(groups))
|
||
accountMap := make(map[int64]*dbent.Account, len(groups))
|
||
for _, ag := range groups {
|
||
if ag.Edges.Account == nil {
|
||
continue
|
||
}
|
||
if _, exists := accountMap[ag.AccountID]; exists {
|
||
continue
|
||
}
|
||
accountMap[ag.AccountID] = ag.Edges.Account
|
||
orderedIDs = append(orderedIDs, ag.AccountID)
|
||
}
|
||
|
||
accounts := make([]*dbent.Account, 0, len(orderedIDs))
|
||
for _, id := range orderedIDs {
|
||
if acc, ok := accountMap[id]; ok {
|
||
accounts = append(accounts, acc)
|
||
}
|
||
}
|
||
|
||
return r.accountsToService(ctx, accounts)
|
||
}
|
||
|
||
func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) {
|
||
if len(accounts) == 0 {
|
||
return []service.Account{}, nil
|
||
}
|
||
|
||
accountIDs := make([]int64, 0, len(accounts))
|
||
proxyIDs := make([]int64, 0, len(accounts))
|
||
for _, acc := range accounts {
|
||
accountIDs = append(accountIDs, acc.ID)
|
||
if acc.ProxyID != nil {
|
||
proxyIDs = append(proxyIDs, *acc.ProxyID)
|
||
}
|
||
}
|
||
|
||
proxyMap, err := r.loadProxies(ctx, proxyIDs)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
outAccounts := make([]service.Account, 0, len(accounts))
|
||
for _, acc := range accounts {
|
||
out := accountEntityToService(acc)
|
||
if out == nil {
|
||
continue
|
||
}
|
||
if acc.ProxyID != nil {
|
||
if proxy, ok := proxyMap[*acc.ProxyID]; ok {
|
||
out.Proxy = proxy
|
||
}
|
||
}
|
||
if groups, ok := groupsByAccount[acc.ID]; ok {
|
||
out.Groups = groups
|
||
}
|
||
if groupIDs, ok := groupIDsByAccount[acc.ID]; ok {
|
||
out.GroupIDs = groupIDs
|
||
}
|
||
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
|
||
out.AccountGroups = ags
|
||
}
|
||
outAccounts = append(outAccounts, *out)
|
||
}
|
||
|
||
return outAccounts, nil
|
||
}
|
||
|
||
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
|
||
proxyMap := make(map[int64]*service.Proxy)
|
||
if len(proxyIDs) == 0 {
|
||
return proxyMap, nil
|
||
}
|
||
|
||
proxies, err := r.client.Proxy.Query().Where(dbproxy.IDIn(proxyIDs...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
for _, p := range proxies {
|
||
proxyMap[p.ID] = proxyEntityToService(p)
|
||
}
|
||
return proxyMap, nil
|
||
}
|
||
|
||
func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []int64) (map[int64][]*service.Group, map[int64][]int64, map[int64][]service.AccountGroup, error) {
|
||
groupsByAccount := make(map[int64][]*service.Group)
|
||
groupIDsByAccount := make(map[int64][]int64)
|
||
accountGroupsByAccount := make(map[int64][]service.AccountGroup)
|
||
|
||
if len(accountIDs) == 0 {
|
||
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
||
}
|
||
|
||
entries, err := r.client.AccountGroup.Query().
|
||
Where(dbaccountgroup.AccountIDIn(accountIDs...)).
|
||
WithGroup().
|
||
Order(dbaccountgroup.ByAccountID(), dbaccountgroup.ByPriority()).
|
||
All(ctx)
|
||
if err != nil {
|
||
return nil, nil, nil, err
|
||
}
|
||
|
||
for _, ag := range entries {
|
||
groupSvc := groupEntityToService(ag.Edges.Group)
|
||
agSvc := service.AccountGroup{
|
||
AccountID: ag.AccountID,
|
||
GroupID: ag.GroupID,
|
||
Priority: ag.Priority,
|
||
CreatedAt: ag.CreatedAt,
|
||
Group: groupSvc,
|
||
}
|
||
accountGroupsByAccount[ag.AccountID] = append(accountGroupsByAccount[ag.AccountID], agSvc)
|
||
groupIDsByAccount[ag.AccountID] = append(groupIDsByAccount[ag.AccountID], ag.GroupID)
|
||
if groupSvc != nil {
|
||
groupsByAccount[ag.AccountID] = append(groupsByAccount[ag.AccountID], groupSvc)
|
||
}
|
||
}
|
||
|
||
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
||
}
|
||
|
||
func accountEntityToService(m *dbent.Account) *service.Account {
|
||
if m == nil {
|
||
return nil
|
||
}
|
||
|
||
return &service.Account{
|
||
ID: m.ID,
|
||
Name: m.Name,
|
||
Platform: m.Platform,
|
||
Type: m.Type,
|
||
Credentials: copyJSONMap(m.Credentials),
|
||
Extra: copyJSONMap(m.Extra),
|
||
ProxyID: m.ProxyID,
|
||
Concurrency: m.Concurrency,
|
||
Priority: m.Priority,
|
||
Status: m.Status,
|
||
ErrorMessage: derefString(m.ErrorMessage),
|
||
LastUsedAt: m.LastUsedAt,
|
||
CreatedAt: m.CreatedAt,
|
||
UpdatedAt: m.UpdatedAt,
|
||
Schedulable: m.Schedulable,
|
||
RateLimitedAt: m.RateLimitedAt,
|
||
RateLimitResetAt: m.RateLimitResetAt,
|
||
OverloadUntil: m.OverloadUntil,
|
||
SessionWindowStart: m.SessionWindowStart,
|
||
SessionWindowEnd: m.SessionWindowEnd,
|
||
SessionWindowStatus: derefString(m.SessionWindowStatus),
|
||
}
|
||
}
|
||
|
||
// normalizeJSONMap returns in itself when it is non-nil, and a fresh empty
// map otherwise. Callers therefore never hand a nil map to the JSON column
// setters.
func normalizeJSONMap(in map[string]any) map[string]any {
	if in != nil {
		return in
	}
	return map[string]any{}
}
|
||
|
||
// copyJSONMap returns a shallow copy of in: top-level keys are duplicated,
// while nested values are shared. A nil input stays nil; an empty non-nil
// input yields an empty non-nil map.
func copyJSONMap(in map[string]any) map[string]any {
	var dup map[string]any
	if in != nil {
		dup = make(map[string]any, len(in))
		for key, val := range in {
			dup[key] = val
		}
	}
	return dup
}
|
||
|
||
// joinClauses concatenates SQL fragments with sep between them, e.g. the
// SET assignments of an UPDATE. An empty or nil slice yields "".
//
// It delegates to strings.Join, which sizes the result once, instead of the
// quadratic repeated string += the hand-rolled loop performed.
func joinClauses(clauses []string, sep string) string {
	return strings.Join(clauses, sep)
}
|
||
|
||
// itoa renders v in base 10. It is a short alias used when building SQL
// placeholders such as "$3".
func itoa(v int) string {
	return strconv.FormatInt(int64(v), 10)
}
|