538 lines
22 KiB
Python
538 lines
22 KiB
Python
import asyncio
|
||
import email
|
||
from dataclasses import dataclass
|
||
from email.header import decode_header, make_header
|
||
from typing import Dict, List, Optional
|
||
|
||
import aiohttp
|
||
from loguru import logger
|
||
import aiomysql
|
||
|
||
from core.config import Config
|
||
from core.database import DatabaseManager
|
||
from core.exceptions import EmailError
|
||
|
||
|
||
@dataclass
class EmailAccount:
    """One row of the ``email_accounts`` table as handed out by ``EmailManager``.

    The first five fields are required. The remaining ones default to an
    unclaimed, unsold account with status ``'pending'`` and no Cursor
    credentials yet.
    """

    # Primary key of the email_accounts row.
    id: int
    # Mailbox address of the account.
    email: str
    # Value of the ``password`` column. An earlier note claimed it actually
    # holds the refresh token -- NOTE(review): confirm against the data.
    password: str
    # Presumably the OAuth client id paired with ``refresh_token`` -- TODO confirm.
    client_id: str
    refresh_token: str
    # True while a worker has claimed the account.
    in_use: bool = False
    # Credentials captured after a successful Cursor registration (if any).
    cursor_password: Optional[str] = None
    cursor_cookie: Optional[str] = None
    sold: bool = False
    # Lifecycle state; values used in this file: 'pending', 'unavailable', 'success'.
    status: str = 'pending'
|
||
|
||
|
||
class EmailManager:
    """Manage pooled email accounts and read Cursor verification codes from Outlook.

    Account state lives in the MySQL table ``email_accounts`` (``in_use``,
    ``sold``, ``status``, ...). When the database manager exposes a truthy
    ``redis`` attribute, Redis is additionally used for per-account locks and
    for a blacklist set of emails whose status is ``'success'`` or
    ``'unavailable'``. Without Redis, every check falls back to MySQL.
    """

    def __init__(self, config: Config, db_manager: DatabaseManager):
        self.config = config
        self.db = db_manager
        # Subject substrings that identify a Cursor verification email.
        self.verification_subjects = [
            "Verify your email address",
            "Complete code challenge",
        ]

        # Redis is optional: without it, state is managed through MySQL only.
        self.use_redis = False
        if getattr(self.db, 'redis', None):
            self.use_redis = True
            logger.info("Redis可用,将使用Redis进行邮箱状态管理")
        else:
            logger.warning("Redis不可用,将使用MySQL进行邮箱状态管理")

        # Prefix shared by every Redis key this class creates.
        self.redis_prefix = "emailmanager:"
        # TTLs in seconds for account locks and for the blacklist set / marker.
        self.lock_timeout = 600  # 10 minutes
        self.blacklist_timeout = 86400 * 30  # 30 days
        # Set by initialize(); not read elsewhere in this class.
        self.blacklist_initialized = False

    async def initialize(self):
        """Finish async setup; call once before using the manager.

        Seeds the Redis blacklist when Redis is in use.
        """
        if self.use_redis:
            await self._ensure_blacklist_initialized()
        self.blacklist_initialized = True

    async def _ensure_blacklist_initialized(self):
        """Seed the Redis blacklist from MySQL unless the init marker key exists.

        Every email whose status is 'success' or 'unavailable' is added to the
        blacklist set. Afterwards a marker key is written with the blacklist TTL,
        so the seed is repeated once the marker expires.

        Returns:
            False when Redis is not in use, True otherwise.
        """
        if not self.use_redis:
            return False

        marker_key = f"{self.redis_prefix}blacklist:initialized"
        if await self.db.redis.exists(marker_key):
            return True

        logger.info("初始化Redis邮箱黑名单...")
        # Every email that already succeeded or was marked unavailable.
        query = """
            SELECT DISTINCT email
            FROM email_accounts
            WHERE status = 'success' OR status = 'unavailable'
        """
        rows = await self.db.fetch_all(query)
        emails = [row['email'] for row in rows or []]
        if emails:
            blacklist_key = f"{self.redis_prefix}blacklist:emails"
            pipe = self.db.redis.pipeline()
            for address in emails:
                pipe.sadd(blacklist_key, address)
            pipe.expire(blacklist_key, self.blacklist_timeout)
            await pipe.execute()
            logger.info(f"已将 {len(emails)} 个邮箱添加到黑名单")

        await self.db.redis.setex(marker_key, self.blacklist_timeout, "1")
        return True

    async def is_email_blacklisted(self, email: str) -> bool:
        """Return True if ``email`` already finished (status 'success' or 'unavailable').

        Uses the Redis blacklist set when Redis is available; otherwise queries MySQL.
        """
        if not self.use_redis:
            # Fall back to querying the database directly.
            query = """
                SELECT 1 FROM email_accounts
                WHERE email = %s AND (status = 'success' OR status = 'unavailable')
                LIMIT 1
            """
            result = await self.db.fetch_one(query, (email,))
            return result is not None

        blacklist_key = f"{self.redis_prefix}blacklist:emails"
        # Some clients return 0/1 for SISMEMBER; normalise to a real bool.
        return bool(await self.db.redis.sismember(blacklist_key, email))

    async def add_email_to_blacklist(self, email: str):
        """Add ``email`` to the Redis blacklist set and refresh the set's TTL.

        Returns:
            False (no-op) when Redis is not in use, True otherwise.
        """
        if not self.use_redis:
            return False

        blacklist_key = f"{self.redis_prefix}blacklist:emails"
        await self.db.redis.sadd(blacklist_key, email)
        await self.db.redis.expire(blacklist_key, self.blacklist_timeout)
        logger.debug(f"邮箱 {email} 已添加到黑名单")
        return True

    async def lock_account(self, account_id: int) -> bool:
        """Try to claim an account for exclusive use.

        Without Redis, this is a conditional ``UPDATE ... WHERE in_use = 0``.
        With Redis, a lock key is taken first (SET NX with a ``lock_timeout``
        TTL) and ``in_use`` is then set in MySQL. If that update fails, the Redis
        lock is released again.

        Returns:
            True if this caller now holds the account, False otherwise.
        """
        if not self.use_redis:
            # Without Redis, the conditional UPDATE itself is the lock.
            try:
                query = """
                    UPDATE email_accounts
                    SET in_use = 1, updated_at = CURRENT_TIMESTAMP
                    WHERE id = %s AND in_use = 0
                """
                affected = await self.db.execute(query, (account_id,))
                return affected > 0
            except Exception as e:
                logger.error(f"通过数据库锁定账号 {account_id} 失败: {e}")
                return False

        lock_key = f"{self.redis_prefix}lock:account:{account_id}"
        # Atomic SET NX EX: a separate SETNX + EXPIRE could leave a lock with no
        # TTL (held forever) if the task were cancelled or died between the calls.
        locked = await self.db.redis.set(lock_key, "1", nx=True, ex=self.lock_timeout)
        if not locked:
            return False

        # Mirror the lock in MySQL so the candidate query skips this row.
        try:
            update_query = """
                UPDATE email_accounts
                SET in_use = 1, updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            """
            await self.db.execute(update_query, (account_id,))
        except Exception as e:
            # The database update failed, so release the Redis lock again.
            logger.error(f"锁定账号 {account_id} 后更新数据库失败: {e}")
            await self.db.redis.delete(lock_key)
            return False

        logger.debug(f"账号 {account_id} 已锁定")
        return True

    async def unlock_account(self, account_id: int) -> bool:
        """Release an account: clear ``in_use`` in MySQL and drop its Redis lock.

        A failing MySQL update is logged, not raised.

        Returns:
            Without Redis, always True. With Redis, whether a lock key was deleted.
        """
        # The database is updated whether or not Redis is in use.
        update_query = """
            UPDATE email_accounts
            SET in_use = 0, updated_at = CURRENT_TIMESTAMP
            WHERE id = %s
        """
        try:
            await self.db.execute(update_query, (account_id,))
        except Exception as e:
            logger.error(f"解锁账号 {account_id} 更新数据库失败: {e}")

        if not self.use_redis:
            return True

        lock_key = f"{self.redis_prefix}lock:account:{account_id}"
        deleted = await self.db.redis.delete(lock_key)
        logger.debug(f"账号 {account_id} 锁已释放")
        return deleted > 0

    async def batch_get_accounts(self, num: int) -> List[EmailAccount]:
        """Claim up to ``num`` pending, unsold, unused accounts.

        Up to ``2 * num`` candidates are read from MySQL. Blacklisted emails are
        skipped, and each remaining candidate is locked via :meth:`lock_account`.
        If the loop is interrupted (error or cancellation), accounts already
        locked by this call are unlocked before the exception propagates.

        Args:
            num: Maximum number of accounts to return; ``num <= 0`` returns ``[]``.

        Returns:
            The locked accounts, possibly fewer than ``num``.
        """
        logger.info(f"尝试获取 {num} 个未使用的邮箱账号")
        if num <= 0:
            return []

        if self.use_redis:
            await self._ensure_blacklist_initialized()

        # 1. Read candidate rows from the database.
        select_query = """
            SELECT id, email, password, client_id, refresh_token
            FROM email_accounts
            WHERE in_use = 0 AND sold = 0 AND status = 'pending'
            LIMIT %s
        """
        # Over-fetch so blacklisted or concurrently claimed rows don't leave us short.
        candidate_accounts = await self.db.fetch_all(select_query, (num * 2,))
        if not candidate_accounts:
            logger.debug("没有找到符合条件的候选账号")
            return []
        logger.debug(f"找到 {len(candidate_accounts)} 个候选账号")

        # 2. Filter and lock.
        result_accounts: List[EmailAccount] = []
        try:
            for account in candidate_accounts:
                if await self.is_email_blacklisted(account['email']):
                    logger.debug(f"邮箱 {account['email']} 在黑名单中,跳过")
                    continue

                if not await self.lock_account(account['id']):
                    logger.debug(f"账号 {account['id']} 锁定失败,可能被其他进程使用")
                    continue

                result_accounts.append(EmailAccount(
                    id=account['id'],
                    email=account['email'],
                    password=account['password'],
                    client_id=account['client_id'],
                    refresh_token=account['refresh_token'],
                    in_use=True,
                ))
                if len(result_accounts) >= num:
                    break
        except BaseException:
            # Includes cancellation: don't leave accounts we already locked
            # stuck in_use until the 30-minute cleanup picks them up.
            for claimed in result_accounts:
                try:
                    await self.unlock_account(claimed.id)
                except Exception as unlock_error:
                    logger.error(f"回滚时解锁账号 {claimed.id} 失败: {unlock_error}")
            raise

        logger.info(f"实际获取到 {len(result_accounts)} 个可用账号")

        # Some candidates were rejected: try to free accounts stuck in_use.
        if len(result_accounts) < num and len(result_accounts) < len(candidate_accounts):
            logger.warning("可用账号不足,尝试清理长时间锁定的账号")
            await self._cleanup_stuck_accounts()

        return result_accounts

    async def _cleanup_stuck_accounts(self):
        """Reset ``in_use`` on accounts not updated for 30 minutes; errors are logged.

        Redis lock keys are not touched here. They are created with a
        ``lock_timeout`` TTL and expire on their own.
        """
        try:
            cleanup_query = """
                UPDATE email_accounts
                SET in_use = 0
                WHERE in_use = 1
                AND updated_at < DATE_SUB(NOW(), INTERVAL 30 MINUTE)
            """
            affected = await self.db.execute(cleanup_query)
            if affected > 0:
                logger.info(f"已清理 {affected} 个长时间锁定的账号")
        except Exception as e:
            logger.error(f"清理账号时出错: {e}")

    async def update_account_status(self, account_id: int, status: str):
        """Set an account's ``status``, release it, and blacklist finished emails.

        Status values used in this file are 'pending', 'unavailable' and
        'success'; the latter two add the email to the Redis blacklist. Moving an
        account away from 'success' is allowed but logged as a warning. A
        missing account is logged and ignored. On any other error the account is
        still unlocked and the error is re-raised.
        """
        try:
            # Fetch the email (for the blacklist) and the current status.
            account_query = "SELECT email, status as current_status FROM email_accounts WHERE id = %s"
            account_info = await self.db.fetch_one(account_query, (account_id,))
            if not account_info:
                logger.error(f"账号 {account_id} 不存在")
                return

            current_status = account_info.get('current_status')
            if current_status == 'success' and status != 'success':
                # Deliberately only a warning: the downgrade still proceeds.
                logger.warning(f"警告: 尝试将成功账号 {account_id} 状态改为 {status},这可能是不正确的操作")

            query = '''
                UPDATE email_accounts
                SET
                    status = %s,
                    in_use = 0,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            '''
            await self.db.execute(query, (status, account_id))
            logger.info(f"账号 {account_id} 状态已更新为 {status}")

            if status in ('success', 'unavailable'):
                address = account_info['email']
                logger.debug(f"将邮箱 {address} 添加到黑名单 (状态: {status})")
                await self.add_email_to_blacklist(address)

            await self.unlock_account(account_id)

            # getattr: the db manager may not define `redis` at all.
            if getattr(self.db, 'redis', None):
                await self.db.clear_cache("db:*email_accounts*")

        except Exception as e:
            logger.error(f"更新账号 {account_id} 状态为 {status} 时出错: {e}")
            # Make sure the account is unlocked no matter what.
            try:
                await self.unlock_account(account_id)
            except Exception as unlock_error:
                logger.error(f"尝试解锁账号 {account_id} 时出错: {unlock_error}")
            raise

    async def update_account(self, account_id: int, cursor_password: str, cursor_cookie: str, cursor_token: str):
        """Record a successful registration: store Cursor credentials, mark sold/success.

        The email is blacklisted and the account unlocked. On error the account
        is still unlocked (unlock failures are logged) and the error is re-raised.
        """
        try:
            # Fetch the email so it can be blacklisted.
            account_query = "SELECT email FROM email_accounts WHERE id = %s"
            account_info = await self.db.fetch_one(account_query, (account_id,))

            query = '''
                UPDATE email_accounts
                SET
                    cursor_password = %s,
                    cursor_cookie = %s,
                    cursor_token = %s,
                    in_use = 0,
                    sold = 1,
                    status = 'success',
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            '''
            await self.db.execute(query, (cursor_password, cursor_cookie, cursor_token, account_id))
            logger.info(f"账号 {account_id} 更新为注册成功")

            if account_info:
                address = account_info['email']
                logger.debug(f"将邮箱 {address} 添加到黑名单 (注册成功)")
                await self.add_email_to_blacklist(address)

            await self.unlock_account(account_id)

            # getattr: the db manager may not define `redis` at all.
            if getattr(self.db, 'redis', None):
                await self.db.clear_cache("db:*email_accounts*")

        except Exception as e:
            logger.error(f"更新账号 {account_id} 信息时出错: {e}")
            # Make sure the account is unlocked no matter what, but don't hide why it failed.
            try:
                await self.unlock_account(account_id)
            except Exception as unlock_error:
                logger.error(f"尝试解锁账号 {account_id} 时出错: {unlock_error}")
            raise

    async def release_account(self, account_id: int):
        """Unlock an account without changing its status; errors are logged and re-raised."""
        try:
            await self.unlock_account(account_id)
            logger.info(f"账号 {account_id} 已释放")

            # getattr: the db manager may not define `redis` at all.
            if getattr(self.db, 'redis', None):
                await self.db.clear_cache("db:*email_accounts*")

        except Exception as e:
            logger.error(f"释放账号 {account_id} 时出错: {e}")
            raise

    async def count_pending_accounts(self) -> int:
        """Count pending, unused, unsold accounts.

        The blacklist is not applied here (the table may be large). Blacklisted
        emails are filtered later, when accounts are actually claimed.
        """
        if self.use_redis:
            await self._ensure_blacklist_initialized()

        # Explicit alias so the result key doesn't depend on the driver's column label.
        query = """
            SELECT COUNT(*) AS pending_count
            FROM email_accounts
            WHERE status = 'pending' AND in_use = 0 AND sold = 0
        """
        result = await self.db.fetch_one(query)
        if result:
            return result.get("pending_count", 0)
        return 0

    async def _get_access_token(self, client_id: str, refresh_token: str) -> str:
        """Exchange a Microsoft OAuth refresh token for an access token.

        Raises:
            EmailError: if the token endpoint reports an error or the response
                contains no access token.
        """
        logger.debug(f"开始获取 access token - client_id: {client_id}")

        url = 'https://login.microsoftonline.com/common/oauth2/v2.0/token'
        data = {
            'client_id': client_id,
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                result = await response.json()

        if 'error' in result:
            error = result.get('error')
            logger.error(f"获取 access token 失败: {error}")
            raise EmailError(f"Failed to get access token: {error}")

        access_token = result.get('access_token')
        if not access_token:
            logger.error("获取 access token 失败: 响应中缺少 access_token")
            raise EmailError("Failed to get access token: missing access_token in response")

        logger.debug("成功获取 access token")
        return access_token

    async def get_verification_code(self, email: str, refresh_token: str, client_id: str) -> str:
        """Poll the Outlook inbox of ``email`` for a Cursor verification code.

        Obtains an access token, logs in over IMAP with XOAUTH2, then makes up
        to 15 attempts one second apart. Each attempt checks the three newest
        messages from no-reply@cursor.sh whose subject matches
        ``verification_subjects``. Blocking ``imaplib`` calls run in the default
        executor so the event loop is not stalled. The IMAP session is always
        closed and logged out.

        Raises:
            EmailError: if no code arrives in time or any step fails.
        """
        logger.info(f"开始获取邮箱验证码 - {email}")
        loop = asyncio.get_running_loop()
        mail = None
        try:
            # 1. Access token.
            access_token = await self._get_access_token(client_id, refresh_token)
            logger.debug(f"[{email}] 获取 access token 成功")

            # 2. XOAUTH2 auth string; fields are separated by the \x01 control character.
            auth_string = f"user={email}\1auth=Bearer {access_token}\1\1"
            logger.debug(f"[{email}] 认证字符串构建完成")

            # 3. Connect (blocking, so off the event loop).
            mail = await loop.run_in_executor(None, self._imap_connect, auth_string)
            logger.debug(f"[{email}] 邮箱连接成功")

            # 4. Poll for the verification mail.
            for attempt in range(15):
                logger.debug(f"[{email}] 第 {attempt + 1} 次尝试获取验证码")
                code = await loop.run_in_executor(None, self._poll_mailbox_for_code, mail, email)
                if code:
                    logger.debug(f"[{email}] 成功获取验证码: {code}")
                    return code
                await asyncio.sleep(1)

            logger.error(f"[{email}] 验证码邮件未收到")
            raise EmailError("Verification code not received")

        except EmailError:
            # Already descriptive; don't wrap it in a second EmailError.
            raise
        except Exception as e:
            logger.error(f"[{email}] 获取验证码失败: {str(e)}")
            raise EmailError(f"Failed to get verification code: {str(e)}") from e
        finally:
            if mail is not None:
                await loop.run_in_executor(None, self._imap_disconnect, mail, email)

    @staticmethod
    def _imap_connect(auth_string: str):
        """Blocking: open an IMAP SSL session to Outlook, authenticate, select INBOX.

        If authentication or selection fails, the half-open connection is
        logged out (best effort) before the original error is re-raised.
        """
        import imaplib
        from contextlib import suppress

        mail = imaplib.IMAP4_SSL('outlook.live.com')
        try:
            mail.authenticate('XOAUTH2', lambda _challenge: auth_string)
            mail.select('inbox')
        except Exception:
            with suppress(Exception):
                mail.logout()
            raise
        return mail

    @staticmethod
    def _imap_disconnect(mail, email: str) -> None:
        """Blocking: close the selected mailbox and log out; failures are only logged."""
        try:
            if mail.state == 'SELECTED':
                mail.close()
        except Exception as e:
            logger.warning(f"[{email}] 关闭邮箱失败: {e}")
        try:
            mail.logout()
        except Exception as e:
            logger.warning(f"[{email}] 退出邮箱登录失败: {e}")

    def _poll_mailbox_for_code(self, mail, email: str) -> Optional[str]:
        """Blocking: check the selected mailbox once for a Cursor verification code.

        Looks at the three highest-numbered messages from no-reply@cursor.sh and
        returns the first code found, or None.
        """
        # Local import: the `email` parameter shadows the module name in this scope.
        from email import message_from_bytes

        result, data = mail.search(None, '(FROM "no-reply@cursor.sh")')
        if result != "OK" or not data or not data[0]:
            logger.debug(f"[{email}] 未找到来自 cursor 的邮件,等待1秒后重试")
            return None

        mail_ids = data[0].split()
        if not mail_ids:
            logger.debug(f"[{email}] 邮件ID列表为空,等待1秒后重试")
            return None

        # IDs are bytes such as b'9', b'10': sort numerically. A plain bytes sort
        # would rank b'9' above b'10' and could skip the newest messages.
        for mail_id in sorted(mail_ids, key=int, reverse=True)[:3]:
            result, msg_data = mail.fetch(mail_id, "(RFC822)")
            if result != 'OK':
                logger.warning(f"[{email}] 获取邮件内容失败: {result}")
                continue

            # imaplib normally returns [(header_bytes, raw_message_bytes), b')'].
            if not msg_data or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
                logger.warning(f"[{email}] 邮件数据格式不正确")
                continue

            email_message = message_from_bytes(msg_data[0][1])

            # Missing headers default to '' instead of crashing decode_header(None).
            from_addr = str(make_header(decode_header(email_message.get('From', ''))))
            if 'no-reply@cursor.sh' not in from_addr:
                logger.debug(f"[{email}] 跳过非 Cursor 邮件,发件人: {from_addr}")
                continue

            subject = str(make_header(decode_header(email_message.get('Subject', ''))))
            if not any(wanted in subject for wanted in self.verification_subjects):
                logger.debug(f"[{email}] 跳过非验证码邮件,主题: {subject}")
                continue

            code = self._extract_code_from_email(email_message)
            if code:
                return code

        return None

    @staticmethod
    def _decode_part(part) -> str:
        """Decode one non-multipart message part to text ('' when it has no payload).

        Uses the part's declared charset, falling back to UTF-8 when none is
        declared or the name is unknown. Undecodable bytes are dropped.
        """
        payload = part.get_payload(decode=True)
        if not payload:
            return ''
        charset = part.get_content_charset() or 'utf-8'
        try:
            return payload.decode(charset, errors='ignore')
        except LookupError:  # unknown charset name in the header
            return payload.decode('utf-8', errors='ignore')

    def _get_email_body(self, email_message) -> str:
        """Return the message body text, preferring text/html over text/plain ('' if neither)."""
        if not email_message.is_multipart():
            return self._decode_part(email_message)

        plain_part = None
        for part in email_message.walk():
            content_type = part.get_content_type()
            if content_type == "text/html":
                return self._decode_part(part)
            if content_type == "text/plain" and plain_part is None:
                plain_part = part
        return self._decode_part(plain_part) if plain_part is not None else ''

    def _extract_code_from_email(self, email_message) -> Optional[str]:
        """Extract a 6-digit verification code from a parsed email message.

        First looks for a <div> whose only content is six digits (surrounding
        whitespace allowed). Failing that, it searches the visible text for a
        standalone 6-digit number not preceded by '#'. For that search,
        <style>/<script> blocks and tags are removed first, so CSS colours like
        #333333 are not mistaken for a code.

        Returns:
            The code, or None if none is found or parsing fails (errors are logged).
        """
        import re

        try:
            body = self._get_email_body(email_message)

            # Primary: a div containing only the 6-digit code.
            match = re.search(r'<div[^>]*>\s*(\d{6})\s*</div>', body)
            if match:
                code = match.group(1)
                logger.debug(f"从HTML中提取到验证码: {code}")
                return code

            # Fallback: any standalone 6-digit number in the visible text.
            visible = re.sub(r'(?is)<(style|script)\b.*?</\1\s*>', ' ', body)
            visible = re.sub(r'<[^>]+>', ' ', visible)
            match = re.search(r'(?<![#\w])\d{6}(?!\w)', visible)
            if match:
                code = match.group(0)
                logger.debug(f"从文本中提取到验证码: {code}")
                return code

            logger.warning("未能从邮件中提取到验证码")
            logger.debug("邮件内容预览: " + body[:200])
            return None

        except Exception as e:
            logger.error(f"提取验证码失败: {str(e)}")
            return None