package service import ( "context" "fmt" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) var ( ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found") ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil") ) type AccountRepository interface { Create(ctx context.Context, account *Account) error GetByID(ctx context.Context, id int64) (*Account, error) // GetByIDs fetches accounts by IDs in a single query. // It should return all accounts found (missing IDs are ignored). GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) // ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查 ExistsByID(ctx context.Context, id int64) (bool, error) // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) UpdateLastUsed(ctx context.Context, id int64) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error SetError(ctx context.Context, id int64, errorMsg string) error ClearError(ctx context.Context, id int64) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error ListSchedulable(ctx context.Context) ([]Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error ClearTempUnschedulable(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error ClearAntigravityQuotaScopes(ctx context.Context, id int64) error ClearModelRateLimits(ctx context.Context, id int64) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) } // AccountBulkUpdate describes the fields that can be updated in a bulk operation. // Nil pointers mean "do not change". type AccountBulkUpdate struct { Name *string ProxyID *int64 Concurrency *int Priority *int RateMultiplier *float64 Status *string Schedulable *bool Credentials map[string]any Extra map[string]any } // CreateAccountRequest 创建账号请求 type CreateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` Platform string `json:"platform"` Type string `json:"type"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency int `json:"concurrency"` Priority int `json:"priority"` GroupIDs []int64 `json:"group_ids"` ExpiresAt *time.Time `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` } // UpdateAccountRequest 更新账号请求 type UpdateAccountRequest struct { Name *string `json:"name"` Notes *string `json:"notes"` Credentials *map[string]any `json:"credentials"` Extra *map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` Status *string `json:"status"` GroupIDs *[]int64 `json:"group_ids"` ExpiresAt *time.Time `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` } // AccountService 账号管理服务 type AccountService struct { accountRepo AccountRepository groupRepo GroupRepository } // NewAccountService 创建账号服务实例 func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { return &AccountService{ accountRepo: accountRepo, groupRepo: groupRepo, } } // Create 创建账号 func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { for _, groupID := range req.GroupIDs { _, err := s.groupRepo.GetByID(ctx, groupID) if err != nil { return nil, fmt.Errorf("get group: %w", err) } } } // 创建账号 account := &Account{ Name: req.Name, Notes: normalizeAccountNotes(req.Notes), Platform: req.Platform, Type: req.Type, Credentials: req.Credentials, Extra: req.Extra, ProxyID: req.ProxyID, Concurrency: req.Concurrency, Priority: req.Priority, Status: StatusActive, ExpiresAt: req.ExpiresAt, } if req.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *req.AutoPauseOnExpired } else { account.AutoPauseOnExpired = true } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, fmt.Errorf("create account: %w", err) } // 绑定分组 if len(req.GroupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil { return nil, fmt.Errorf("bind groups: %w", err) } } return account, nil } // GetByID 根据ID获取账号 func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get account: %w", err) } return account, nil } // List 获取账号列表 func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { accounts, pagination, err := s.accountRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list accounts: %w", err) } return accounts, pagination, nil } // ListByPlatform 根据平台获取账号列表 func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) { accounts, err := s.accountRepo.ListByPlatform(ctx, platform) if err != nil { return nil, fmt.Errorf("list accounts by platform: %w", err) } return accounts, nil } // ListByGroup 根据分组获取账号列表 func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { accounts, err := s.accountRepo.ListByGroup(ctx, groupID) if err != nil { return nil, fmt.Errorf("list accounts by group: %w", err) } return accounts, nil } // Update 更新账号 func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get account: %w", err) } // 更新字段 if req.Name != nil { account.Name = *req.Name } if req.Notes != nil { account.Notes = normalizeAccountNotes(req.Notes) } if req.Credentials != nil { account.Credentials = *req.Credentials } if req.Extra != nil { account.Extra = *req.Extra } if req.ProxyID != nil { account.ProxyID = req.ProxyID } if req.Concurrency != nil { account.Concurrency = *req.Concurrency } if req.Priority != nil { account.Priority = *req.Priority } if req.Status != nil { account.Status = *req.Status } if req.ExpiresAt != nil { account.ExpiresAt = req.ExpiresAt } if req.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *req.AutoPauseOnExpired } // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { for _, groupID := range *req.GroupIDs { _, err := s.groupRepo.GetByID(ctx, groupID) if err != nil { return nil, fmt.Errorf("get group: %w", err) } } } // 执行更新 if err := s.accountRepo.Update(ctx, account); err != nil { return nil, fmt.Errorf("update account: %w", err) } // 绑定分组 if req.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { return nil, fmt.Errorf("bind groups: %w", err) } } return account, nil } // Delete 删除账号 // 优化:使用 ExistsByID 替代 GetByID 进行存在性检查, // 避免加载完整账号对象及其关联数据,提升删除操作的性能 func (s *AccountService) Delete(ctx context.Context, id int64) error { // 使用轻量级的存在性检查,而非加载完整账号对象 exists, err := s.accountRepo.ExistsByID(ctx, id) if err != nil { return fmt.Errorf("check account: %w", err) } // 明确返回账号不存在错误,便于调用方区分错误类型 if !exists { return ErrAccountNotFound } if err := s.accountRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete account: %w", err) } return nil } // UpdateStatus 更新账号状态 func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return fmt.Errorf("get account: %w", err) } account.Status = status account.ErrorMessage = errorMessage if err := s.accountRepo.Update(ctx, account); err != nil { return fmt.Errorf("update account: %w", err) } return nil } // UpdateLastUsed 更新最后使用时间 func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error { if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil { return fmt.Errorf("update last used: %w", err) } return nil } // GetCredential 获取账号凭证(安全访问) func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return "", fmt.Errorf("get account: %w", err) } return account.GetCredential(key), nil } // TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑) func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return fmt.Errorf("get account: %w", err) } // 根据平台执行不同的测试逻辑 switch account.Platform { case PlatformAnthropic: // TODO: 测试Anthropic API凭证 return nil case PlatformOpenAI: // TODO: 测试OpenAI API凭证 return nil case PlatformGemini: // TODO: 测试Gemini API凭证 return nil default: return fmt.Errorf("unsupported platform: %s", account.Platform) } }