Files
yinghuoapi/backend/internal/repository/account_repo.go
2025-12-25 17:15:01 +08:00

402 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// AccountRepository groups the database operations for accounts and their
// account/group bindings. All methods run through the wrapped GORM handle.
type AccountRepository struct {
// db is the GORM handle from which each method derives a context-bound session.
db *gorm.DB
}
// NewAccountRepository returns an AccountRepository that issues its queries
// through db.
func NewAccountRepository(db *gorm.DB) *AccountRepository {
	repo := new(AccountRepository)
	repo.db = db
	return repo
}
// Create inserts account as a new row and returns the error reported by GORM,
// if any.
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
	result := r.db.WithContext(ctx).Create(account)
	return result.Error
}
// GetByID loads the account with the given primary key, preloading its Proxy
// and AccountGroups.Group associations, and then fills the derived GroupIDs
// and Groups fields from the loaded bindings.
//
// Any query error (including GORM's record-not-found error) is returned
// unchanged together with a nil account.
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
	account := new(model.Account)
	query := r.db.WithContext(ctx).
		Preload("Proxy").
		Preload("AccountGroups.Group")
	if err := query.First(account, id).Error; err != nil {
		return nil, err
	}

	// Populate the virtual GroupIDs and Groups fields from the bindings.
	bindings := account.AccountGroups
	groupIDs := make([]int64, 0, len(bindings))
	groups := make([]*model.Group, 0, len(bindings))
	for i := range bindings {
		groupIDs = append(groupIDs, bindings[i].GroupID)
		// A binding whose Group was not loaded contributes only its ID.
		if grp := bindings[i].Group; grp != nil {
			groups = append(groups, grp)
		}
	}
	account.GroupIDs = groupIDs
	account.Groups = groups

	return account, nil
}
// GetByCRSAccountID looks up the first account whose extra JSON field
// "crs_account_id" equals crsAccountID.
//
// It returns (nil, nil) when crsAccountID is empty or when no row matches;
// any other query error is returned as-is.
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
	if crsAccountID == "" {
		return nil, nil
	}

	account := new(model.Account)
	err := r.db.WithContext(ctx).
		Where("extra->>'crs_account_id' = ?", crsAccountID).
		First(account).Error

	switch {
	case err == nil:
		return account, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is reported as "no account", not as an error.
		return nil, nil
	default:
		return nil, err
	}
}
// Update persists account via GORM's Save and returns the resulting error,
// if any.
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
	result := r.db.WithContext(ctx).Save(account)
	return result.Error
}
// Delete removes the account with the given id together with all of its
// account/group bindings.
//
// Both deletes run inside a single transaction so that a failure while
// deleting the account does not leave it behind with its bindings already
// removed. The first error encountered is returned and rolls everything back.
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Remove the account's group bindings first.
		if err := tx.Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
			return err
		}
		// Then remove the account itself.
		return tx.Delete(&model.Account{}, id).Error
	})
}
// List returns one page of accounts with no filters applied. It delegates to
// ListWithFilters with every filter left empty.
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
	const unfiltered = ""
	return r.ListWithFilters(ctx, params, unfiltered, unfiltered, unfiltered, unfiltered)
}
// ListWithFilters returns one page of accounts, newest id first, optionally
// narrowed by platform, type and status (exact match) and by a
// case-insensitive substring match on name. An empty filter value disables
// that filter.
//
// Each returned account has Proxy and AccountGroups.Group preloaded and its
// GroupIDs and Groups fields filled from those bindings. The pagination result
// carries the total number of matching rows and the derived page count.
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
	query := r.db.WithContext(ctx).Model(&model.Account{})

	// Exact-match filters, applied only when a value was supplied.
	exactFilters := []struct{ condition, value string }{
		{"platform = ?", platform},
		{"type = ?", accountType},
		{"status = ?", status},
	}
	for _, f := range exactFilters {
		if f.value == "" {
			continue
		}
		query = query.Where(f.condition, f.value)
	}
	if search != "" {
		query = query.Where("name ILIKE ?", "%"+search+"%")
	}

	// Count the matching rows before offset/limit are applied.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, nil, err
	}

	var accounts []model.Account
	pageQuery := query.
		Preload("Proxy").
		Preload("AccountGroups.Group").
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order("id DESC")
	if err := pageQuery.Find(&accounts).Error; err != nil {
		return nil, nil, err
	}

	// Populate the virtual GroupIDs and Groups fields of every account.
	for i := range accounts {
		acc := &accounts[i]
		acc.GroupIDs = make([]int64, 0, len(acc.AccountGroups))
		acc.Groups = make([]*model.Group, 0, len(acc.AccountGroups))
		for j := range acc.AccountGroups {
			acc.GroupIDs = append(acc.GroupIDs, acc.AccountGroups[j].GroupID)
			if grp := acc.AccountGroups[j].Group; grp != nil {
				acc.Groups = append(acc.Groups, grp)
			}
		}
	}

	// Page count is total/pageSize rounded up.
	pageSize := params.Limit()
	rowCount := int(total)
	pages := rowCount / pageSize
	if rowCount%pageSize > 0 {
		pages++
	}

	result := &pagination.PaginationResult{
		Total:    total,
		Page:     params.Page,
		PageSize: pageSize,
		Pages:    pages,
	}
	return accounts, result, nil
}
// ListByGroup returns the accounts bound to groupID whose status is
// model.StatusActive, with Proxy preloaded. Results are ordered by the
// binding priority and then by the account priority, both ascending.
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
	query := r.db.WithContext(ctx).
		Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
		Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
		Preload("Proxy").
		Order("account_groups.priority ASC, accounts.priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// ListActive returns every account whose status is model.StatusActive, with
// Proxy preloaded, ordered by ascending priority.
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
	query := r.db.WithContext(ctx).
		Where("status = ?", model.StatusActive).
		Preload("Proxy").
		Order("priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// UpdateLastUsed sets last_used_at of the account with the given id to the
// current time.
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
	usedAt := time.Now()
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	return target.Update("last_used_at", usedAt).Error
}
// SetError switches the account with the given id to model.StatusError and
// stores errorMsg in its error_message column.
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
	fields := map[string]any{
		"error_message": errorMsg,
		"status":        model.StatusError,
	}
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// AddToGroup inserts a binding between accountID and groupID with the given
// priority.
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
	binding := model.AccountGroup{
		AccountID: accountID,
		GroupID:   groupID,
		Priority:  priority,
	}
	result := r.db.WithContext(ctx).Create(&binding)
	return result.Error
}
// RemoveFromGroup deletes the binding between accountID and groupID.
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
	scope := r.db.WithContext(ctx).
		Where("account_id = ? AND group_id = ?", accountID, groupID)
	result := scope.Delete(&model.AccountGroup{})
	return result.Error
}
// GetGroups returns the groups that accountID is bound to, found by joining
// groups with account_groups.
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
	query := r.db.WithContext(ctx).
		Joins("JOIN account_groups ON account_groups.group_id = groups.id").
		Where("account_groups.account_id = ?", accountID)

	var groups []model.Group
	result := query.Find(&groups)
	return groups, result.Error
}
// ListByPlatform returns the accounts on the given platform whose status is
// model.StatusActive, with Proxy preloaded, ordered by ascending priority.
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
	query := r.db.WithContext(ctx).
		Where("platform = ? AND status = ?", platform, model.StatusActive).
		Preload("Proxy").
		Order("priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// BindGroups replaces every group binding of accountID with one binding per
// entry of groupIDs. The binding priority is the entry's 1-based position in
// groupIDs. An empty groupIDs simply removes all existing bindings.
//
// The delete and the insert run inside a single transaction: if inserting the
// new bindings fails, the previous bindings are restored instead of leaving
// the account bound to no group at all.
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Drop the existing bindings.
		if err := tx.Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
			return err
		}
		if len(groupIDs) == 0 {
			return nil
		}

		// Insert the new bindings in a single batch.
		accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
		for i, groupID := range groupIDs {
			accountGroups = append(accountGroups, model.AccountGroup{
				AccountID: accountID,
				GroupID:   groupID,
				Priority:  i + 1, // position in groupIDs doubles as the priority
			})
		}
		return tx.Create(&accountGroups).Error
	})
}
// ListSchedulable returns all schedulable accounts: status is
// model.StatusActive, schedulable is true, and both overload_until and
// rate_limit_reset_at are either NULL or not later than the current time.
// Proxy is preloaded and results are ordered by ascending priority.
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
	asOf := time.Now()
	query := r.db.WithContext(ctx).
		Where("status = ? AND schedulable = ?", model.StatusActive, true).
		Where("(overload_until IS NULL OR overload_until <= ?)", asOf).
		Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", asOf).
		Preload("Proxy").
		Order("priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// ListSchedulableByGroupID returns the schedulable accounts bound to groupID:
// status is model.StatusActive, schedulable is true, and both overload_until
// and rate_limit_reset_at are either NULL or not later than the current time.
// Proxy is preloaded; results are ordered by binding priority and then by
// account priority, both ascending.
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
	asOf := time.Now()
	query := r.db.WithContext(ctx).
		Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
		Where("account_groups.group_id = ?", groupID).
		Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
		Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", asOf).
		Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", asOf).
		Preload("Proxy").
		Order("account_groups.priority ASC, accounts.priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// ListSchedulableByPlatform returns the schedulable accounts on the given
// platform: status is model.StatusActive, schedulable is true, and both
// overload_until and rate_limit_reset_at are either NULL or not later than
// the current time. Proxy is preloaded and results are ordered by ascending
// priority.
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
	asOf := time.Now()
	query := r.db.WithContext(ctx).
		Where("platform = ?", platform).
		Where("status = ? AND schedulable = ?", model.StatusActive, true).
		Where("(overload_until IS NULL OR overload_until <= ?)", asOf).
		Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", asOf).
		Preload("Proxy").
		Order("priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// ListSchedulableByGroupIDAndPlatform returns the schedulable accounts that
// are bound to groupID and belong to the given platform: status is
// model.StatusActive, schedulable is true, and both overload_until and
// rate_limit_reset_at are either NULL or not later than the current time.
// Proxy is preloaded; results are ordered by binding priority and then by
// account priority, both ascending.
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
	asOf := time.Now()
	query := r.db.WithContext(ctx).
		Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
		Where("account_groups.group_id = ?", groupID).
		Where("accounts.platform = ?", platform).
		Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
		Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", asOf).
		Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", asOf).
		Preload("Proxy").
		Order("account_groups.priority ASC, accounts.priority ASC")

	var accounts []model.Account
	result := query.Find(&accounts)
	return accounts, result.Error
}
// SetRateLimited marks the account as rate-limited (429): rate_limited_at is
// set to the current time and rate_limit_reset_at to resetAt.
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	fields := map[string]any{
		"rate_limited_at":     time.Now(),
		"rate_limit_reset_at": resetAt,
	}
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// SetOverloaded marks the account as overloaded (529) by setting its
// overload_until column to until.
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	result := target.Update("overload_until", until)
	return result.Error
}
// ClearRateLimit clears the account's rate-limit state by setting
// rate_limited_at, rate_limit_reset_at and overload_until to NULL.
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
	cleared := map[string]any{
		"overload_until":      nil,
		"rate_limit_reset_at": nil,
		"rate_limited_at":     nil,
	}
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	return target.Updates(cleared).Error
}
// UpdateSessionWindow updates the account's 5-hour session window
// information. session_window_status is always written; session_window_start
// and session_window_end are written only when the corresponding pointer is
// non-nil, so a nil start or end leaves the stored value untouched.
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	fields := map[string]any{"session_window_status": status}
	if end != nil {
		fields["session_window_end"] = end
	}
	if start != nil {
		fields["session_window_start"] = start
	}

	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// SetSchedulable sets the scheduling switch (the schedulable column) of the
// account with the given id.
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
	target := r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id)
	result := target.Update("schedulable", schedulable)
	return result.Error
}
// UpdateExtra merges updates into the account's Extra JSONB field: keys in
// updates are added or overwritten, all other existing keys are kept.
//
// An empty updates map is a no-op. If no account with the given id exists,
// the error from the lookup (GORM's record-not-found error) is returned.
//
// The read-merge-write sequence runs inside a transaction and the row is read
// with a FOR UPDATE lock, so two concurrent calls for the same account cannot
// overwrite each other's keys.
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
	if len(updates) == 0 {
		return nil
	}

	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Load the current Extra under a row lock so the merge below is based
		// on the latest committed value.
		var account model.Account
		err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Select("extra").
			Where("id = ?", id).
			First(&account).Error
		if err != nil {
			return err
		}

		// Writing to a nil map would panic, so initialize it first.
		if account.Extra == nil {
			account.Extra = make(model.JSONB)
		}
		for k, v := range updates {
			account.Extra[k] = v
		}

		return tx.Model(&model.Account{}).Where("id = ?", id).
			Update("extra", account.Extra).Error
	})
}
// BulkUpdate applies the non-nil / non-empty fields of updates to every
// account whose id is in ids and returns the number of rows updated.
//
// Scalar fields (name, proxy_id, concurrency, priority, status) are
// overwritten. Credentials and Extra are merged into the existing JSONB value
// with the `||` operator instead of replacing it. The call is a no-op that
// returns (0, nil) when ids is empty or updates carries nothing to change.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}

	updateMap := map[string]any{}
	if updates.Name != nil {
		updateMap["name"] = *updates.Name
	}
	if updates.ProxyID != nil {
		updateMap["proxy_id"] = updates.ProxyID
	}
	if updates.Concurrency != nil {
		updateMap["concurrency"] = *updates.Concurrency
	}
	if updates.Priority != nil {
		updateMap["priority"] = *updates.Priority
	}
	if updates.Status != nil {
		updateMap["status"] = *updates.Status
	}
	// COALESCE guards against a NULL column, for which `NULL || x` stays NULL.
	if len(updates.Credentials) > 0 {
		updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
	}
	if len(updates.Extra) > 0 {
		updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
	}
	if len(updateMap) == 0 {
		return 0, nil
	}

	// With a RETURNING clause GORM derives RowsAffected from the rows it scans
	// into the Model destination. A single struct only takes one row, so the
	// destination is a slice to make every updated row count. Only the id
	// column is returned to avoid shipping credentials back for each row.
	var updated []model.Account
	result := r.db.WithContext(ctx).
		Model(&updated).
		Where("id IN ?", ids).
		Clauses(clause.Returning{Columns: []clause.Column{{Name: "id"}}}).
		Updates(updateMap)
	return result.RowsAffected, result.Error
}