Files
auto_cursor/services/email_manager.py
huangzhenpc d16f6bdc62 ssss
2025-04-02 09:28:25 +08:00

538 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import email
from dataclasses import dataclass
from email.header import decode_header, make_header
from typing import Dict, List, Optional
import aiohttp
from loguru import logger
import aiomysql
from core.config import Config
from core.database import DatabaseManager
from core.exceptions import EmailError
@dataclass()
class EmailAccount:
    """One account row from the ``email_accounts`` table.

    The first five fields are required; the remaining ones have defaults that
    describe an account which is not in use, not sold and still ``pending``.
    """

    id: int
    email: str
    # NOTE(review): the original comment claimed this field actually holds the
    # refresh token, yet a separate refresh_token field exists — confirm.
    password: str
    client_id: str
    refresh_token: str
    in_use: bool = False
    cursor_password: Optional[str] = None
    cursor_cookie: Optional[str] = None
    sold: bool = False
    # Account lifecycle state: "pending", "unavailable" or "success".
    status: str = "pending"
class EmailManager:
    """Hand out accounts from ``email_accounts`` and read Cursor verification codes.

    Account state lives in MySQL (``in_use``, ``sold``, ``status`` columns). When the
    database manager exposes a truthy ``redis`` attribute, Redis is additionally used
    for per-account locks and for a set of blacklisted addresses (addresses whose rows
    reached status ``success`` or ``unavailable``).
    """

    # Statuses that put an address on the blacklist.
    _BLACKLIST_STATUSES = ('success', 'unavailable')
    # Cache-key pattern invalidated after any account state change.
    _CACHE_PATTERN = "db:*email_accounts*"

    def __init__(self, config: Config, db_manager: DatabaseManager):
        self.config = config
        self.db = db_manager
        # Subjects of the Cursor mails that carry a verification code.
        self.verification_subjects = [
            "Verify your email address",
            "Complete code challenge",
        ]
        # Use Redis only if the database manager exposes a usable client.
        self.use_redis = False
        if hasattr(self.db, 'redis') and self.db.redis:
            self.use_redis = True
            logger.info("Redis可用将使用Redis进行邮箱状态管理")
        else:
            logger.warning("Redis不可用将使用MySQL进行邮箱状态管理")
        # Prefix shared by every Redis key written by this class.
        self.redis_prefix = "emailmanager:"
        self.blacklist_key = f"{self.redis_prefix}blacklist:emails"
        self.blacklist_marker_key = f"{self.redis_prefix}blacklist:initialized"
        # Expiry (seconds) of account locks and of the blacklist.
        self.lock_timeout = 600  # 10 minutes
        self.blacklist_timeout = 86400 * 30  # 30 days
        # Set by initialize(); not read anywhere else in this class.
        self.blacklist_initialized = False

    def _lock_key(self, account_id: int) -> str:
        """Return the Redis key used to lock ``account_id``."""
        return f"{self.redis_prefix}lock:account:{account_id}"

    async def _clear_email_cache(self) -> None:
        """Invalidate cached ``email_accounts`` query results when Redis is attached."""
        # getattr rather than attribute access: __init__ tolerates a manager that
        # has no ``redis`` attribute at all.
        if getattr(self.db, 'redis', None):
            await self.db.clear_cache(self._CACHE_PATTERN)

    async def initialize(self):
        """Prepare the manager for use; seeds the Redis blacklist when Redis is enabled."""
        if self.use_redis:
            await self._ensure_blacklist_initialized()
            self.blacklist_initialized = True

    async def _ensure_blacklist_initialized(self):
        """Seed the Redis blacklist from MySQL unless the 'initialized' marker exists.

        Returns:
            False when Redis is disabled, True otherwise.
        """
        if not self.use_redis:
            return False
        if await self.db.redis.exists(self.blacklist_marker_key):
            return True

        logger.info("初始化Redis邮箱黑名单...")
        # Every address that already succeeded or was marked unavailable.
        query = """
            SELECT DISTINCT email
            FROM email_accounts
            WHERE status = 'success' OR status = 'unavailable'
        """
        rows = await self.db.fetch_all(query)
        addresses = [row['email'] for row in rows or []]
        if addresses:
            pipeline = self.db.redis.pipeline()
            for address in addresses:
                pipeline.sadd(self.blacklist_key, address)
            pipeline.expire(self.blacklist_key, self.blacklist_timeout)
            await pipeline.execute()
            logger.info(f"已将 {len(addresses)} 个邮箱添加到黑名单")

        # The marker has the blacklist TTL; once it expires the next call re-seeds.
        await self.db.redis.setex(self.blacklist_marker_key, self.blacklist_timeout, "1")
        return True

    async def is_email_blacklisted(self, email: str) -> bool:
        """Return True if ``email`` already succeeded or was marked unavailable."""
        if not self.use_redis:
            # Without Redis, answer the same question directly from MySQL.
            query = """
                SELECT 1 FROM email_accounts
                WHERE email = %s AND (status = 'success' OR status = 'unavailable')
                LIMIT 1
            """
            result = await self.db.fetch_one(query, (email,))
            return result is not None
        return bool(await self.db.redis.sismember(self.blacklist_key, email))

    async def add_email_to_blacklist(self, email: str):
        """Add ``email`` to the Redis blacklist and refresh its TTL.

        Returns:
            False when Redis is disabled (nothing is recorded), True otherwise.
        """
        if not self.use_redis:
            return False
        await self.db.redis.sadd(self.blacklist_key, email)
        await self.db.redis.expire(self.blacklist_key, self.blacklist_timeout)
        logger.debug(f"邮箱 {email} 已添加到黑名单")
        return True

    async def lock_account(self, account_id: int) -> bool:
        """Try to reserve ``account_id`` for this process.

        Without Redis, the reservation is a conditional ``UPDATE ... WHERE in_use = 0``.
        With Redis, a lock key with ``lock_timeout`` expiry is taken first and the
        row is then flagged ``in_use = 1``; if that update fails, the lock is released.

        Returns:
            True if the account was locked by this call.
        """
        if not self.use_redis:
            try:
                query = """
                    UPDATE email_accounts
                    SET in_use = 1, updated_at = CURRENT_TIMESTAMP
                    WHERE id = %s AND in_use = 0
                """
                affected = await self.db.execute(query, (account_id,))
                return bool(affected)
            except Exception as e:
                logger.error(f"通过数据库锁定账号 {account_id} 失败: {e}")
                return False

        lock_key = self._lock_key(account_id)
        # Acquire and set the expiry in one atomic command so that a crash can
        # never leave behind a lock without a TTL.
        locked = await self.db.redis.set(lock_key, "1", ex=self.lock_timeout, nx=True)
        if not locked:
            return False

        try:
            update_query = """
                UPDATE email_accounts
                SET in_use = 1, updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            """
            await self.db.execute(update_query, (account_id,))
        except Exception as e:
            # Keep Redis and MySQL consistent: give the lock back.
            logger.error(f"锁定账号 {account_id} 后更新数据库失败: {e}")
            await self.db.redis.delete(lock_key)
            return False

        logger.debug(f"账号 {account_id} 已锁定")
        return True

    async def unlock_account(self, account_id: int) -> bool:
        """Clear ``in_use`` for ``account_id`` and drop its Redis lock, if any.

        A failing DB update is logged, not raised.

        Returns:
            True without Redis; otherwise whether a Redis lock key was deleted.
        """
        try:
            update_query = """
                UPDATE email_accounts
                SET in_use = 0, updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            """
            await self.db.execute(update_query, (account_id,))
        except Exception as e:
            logger.error(f"解锁账号 {account_id} 更新数据库失败: {e}")

        if not self.use_redis:
            return True

        deleted = await self.db.redis.delete(self._lock_key(account_id))
        logger.debug(f"账号 {account_id} 锁已释放")
        return deleted > 0

    async def batch_get_accounts(self, num: int) -> List[EmailAccount]:
        """Lock and return up to ``num`` pending, unused, unsold accounts.

        Blacklisted addresses are skipped. If fewer than ``num`` accounts could be
        locked, stale ``in_use`` flags are cleaned up for the next call.
        """
        logger.info(f"尝试获取 {num} 个未使用的邮箱账号")
        if num <= 0:
            # A negative LIMIT is a SQL error; zero can never yield anything.
            return []

        if self.use_redis:
            await self._ensure_blacklist_initialized()

        select_query = """
            SELECT id, email, password, client_id, refresh_token
            FROM email_accounts
            WHERE in_use = 0 AND sold = 0 AND status = 'pending'
            LIMIT %s
        """
        # Fetch twice as many candidates, since some are filtered out below.
        # NOTE(review): skipped (blacklisted) rows stay 'pending' and, with no
        # ORDER BY/offset, may be re-selected on every call; if they fill the
        # window, no account can be handed out.
        candidate_accounts = await self.db.fetch_all(select_query, (num * 2,))
        if not candidate_accounts:
            logger.debug("没有找到符合条件的候选账号")
            return []
        logger.debug(f"找到 {len(candidate_accounts)} 个候选账号")

        result_accounts = []
        for account in candidate_accounts:
            if await self.is_email_blacklisted(account['email']):
                logger.debug(f"邮箱 {account['email']} 在黑名单中,跳过")
                continue
            if not await self.lock_account(account['id']):
                logger.debug(f"账号 {account['id']} 锁定失败,可能被其他进程使用")
                continue
            result_accounts.append(EmailAccount(
                id=account['id'],
                email=account['email'],
                password=account['password'],
                client_id=account['client_id'],
                refresh_token=account['refresh_token'],
                in_use=True,
            ))
            if len(result_accounts) >= num:
                break

        logger.info(f"实际获取到 {len(result_accounts)} 个可用账号")

        if len(result_accounts) < num and len(result_accounts) < len(candidate_accounts):
            logger.warning("可用账号不足,尝试清理长时间锁定的账号")
            await self._cleanup_stuck_accounts()

        return result_accounts

    async def _cleanup_stuck_accounts(self):
        """Reset ``in_use`` on rows not updated for 30 minutes; errors are logged."""
        try:
            cleanup_query = """
                UPDATE email_accounts
                SET in_use = 0
                WHERE in_use = 1
                AND updated_at < DATE_SUB(NOW(), INTERVAL 30 MINUTE)
            """
            affected = await self.db.execute(cleanup_query)
            if affected:
                logger.info(f"已清理 {affected} 个长时间锁定的账号")
            # Redis locks need no cleanup here: they expire after lock_timeout
            # (10 minutes), which is shorter than the 30-minute window above.
        except Exception as e:
            logger.error(f"清理账号时出错: {e}")

    async def update_account_status(self, account_id: int, status: str):
        """Set ``status`` on ``account_id``, clear ``in_use`` and release its lock.

        ``success``/``unavailable`` also blacklist the address. Downgrading a
        ``success`` account is allowed but logged as a warning. An unknown
        ``account_id`` is logged and ignored.

        Raises:
            Any database/Redis error, after a best-effort unlock.
        """
        try:
            account_query = "SELECT email, status as current_status FROM email_accounts WHERE id = %s"
            account_info = await self.db.fetch_one(account_query, (account_id,))
            if not account_info:
                logger.error(f"账号 {account_id} 不存在")
                return

            current_status = account_info.get('current_status')
            if current_status == 'success' and status != 'success':
                logger.warning(f"警告: 尝试将成功账号 {account_id} 状态改为 {status},这可能是不正确的操作")

            query = '''
                UPDATE email_accounts
                SET
                    status = %s,
                    in_use = 0,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            '''
            await self.db.execute(query, (status, account_id))
            logger.info(f"账号 {account_id} 状态已更新为 {status}")

            if status in self._BLACKLIST_STATUSES:
                address = account_info['email']
                logger.debug(f"将邮箱 {address} 添加到黑名单 (状态: {status})")
                await self.add_email_to_blacklist(address)

            await self.unlock_account(account_id)
            await self._clear_email_cache()
        except Exception as e:
            logger.error(f"更新账号 {account_id} 状态为 {status} 时出错: {e}")
            try:
                await self.unlock_account(account_id)
            except Exception as unlock_error:
                logger.error(f"尝试解锁账号 {account_id} 时出错: {unlock_error}")
            raise

    async def update_account(self, account_id: int, cursor_password: str, cursor_cookie: str, cursor_token: str):
        """Record a successful registration and mark the account sold.

        Stores the Cursor credentials, sets status ``success``/``sold = 1``,
        blacklists the address and releases the lock.

        Raises:
            Any database/Redis error, after a best-effort unlock.
        """
        try:
            account_query = "SELECT email FROM email_accounts WHERE id = %s"
            account_info = await self.db.fetch_one(account_query, (account_id,))

            query = '''
                UPDATE email_accounts
                SET
                    cursor_password = %s,
                    cursor_cookie = %s,
                    cursor_token = %s,
                    in_use = 0,
                    sold = 1,
                    status = 'success',
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = %s
            '''
            await self.db.execute(query, (cursor_password, cursor_cookie, cursor_token, account_id))
            logger.info(f"账号 {account_id} 更新为注册成功")

            if account_info:
                address = account_info['email']
                logger.debug(f"将邮箱 {address} 添加到黑名单 (注册成功)")
                await self.add_email_to_blacklist(address)

            await self.unlock_account(account_id)
            await self._clear_email_cache()
        except Exception as e:
            logger.error(f"更新账号 {account_id} 信息时出错: {e}")
            try:
                await self.unlock_account(account_id)
            except Exception as unlock_error:
                # Previously swallowed by a bare except; keep it visible.
                logger.error(f"尝试解锁账号 {account_id} 时出错: {unlock_error}")
            raise

    async def release_account(self, account_id: int):
        """Release ``account_id`` without changing its status.

        Raises:
            Any error from unlocking or cache invalidation (logged first).
        """
        try:
            await self.unlock_account(account_id)
            logger.info(f"账号 {account_id} 已释放")
            await self._clear_email_cache()
        except Exception as e:
            logger.error(f"释放账号 {account_id} 时出错: {e}")
            raise

    async def count_pending_accounts(self) -> int:
        """Count rows that are ``pending``, not in use and not sold.

        The blacklist is deliberately not applied here, since the table may be large;
        it is applied when accounts are actually fetched.
        """
        if self.use_redis:
            await self._ensure_blacklist_initialized()

        # Alias the aggregate so the result key does not depend on the driver's
        # generated column name.
        query = """
            SELECT COUNT(*) AS pending_count
            FROM email_accounts
            WHERE status = 'pending' AND in_use = 0 AND sold = 0
        """
        result = await self.db.fetch_one(query)
        if not result:
            return 0
        return int(result.get("pending_count") or 0)

    async def _get_access_token(self, client_id: str, refresh_token: str) -> str:
        """Exchange a Microsoft refresh token for an access token.

        Raises:
            EmailError: the endpoint returned an error or no ``access_token``.
        """
        logger.debug(f"开始获取 access token - client_id: {client_id}")
        url = 'https://login.microsoftonline.com/common/oauth2/v2.0/token'
        data = {
            'client_id': client_id,
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                result = await response.json()

        if 'error' in result:
            error = result.get('error')
            logger.error(f"获取 access token 失败: {error}")
            raise EmailError(f"Failed to get access token: {error}")

        access_token = result.get('access_token')
        if not access_token:
            logger.error("获取 access token 失败: 响应中缺少 access_token")
            raise EmailError("Failed to get access token: missing access_token in response")

        logger.debug("成功获取 access token")
        return access_token

    async def get_verification_code(self, email: str, refresh_token: str, client_id: str) -> str:
        """Poll the Outlook inbox of ``email`` for a Cursor verification code.

        Makes up to 15 attempts, one second apart. The IMAP session is always
        logged out before returning.

        Raises:
            EmailError: no code was found, or any step (token, IMAP) failed.
        """
        logger.info(f"开始获取邮箱验证码 - {email}")
        loop = asyncio.get_running_loop()
        mail = None
        try:
            access_token = await self._get_access_token(client_id, refresh_token)
            logger.debug(f"[{email}] 获取 access token 成功")

            # XOAUTH2 SASL string: fields separated by \x01, terminated by two.
            auth_string = f"user={email}\1auth=Bearer {access_token}\1\1"
            logger.debug(f"[{email}] 认证字符串构建完成")

            # imaplib is blocking; run it in the default executor so the event
            # loop keeps serving other coroutines while we wait on the server.
            mail = await loop.run_in_executor(None, self._open_mailbox, auth_string)
            logger.debug(f"[{email}] 邮箱连接成功")

            for i in range(15):
                logger.debug(f"[{email}] 第 {i + 1} 次尝试获取验证码")
                code = await loop.run_in_executor(None, self._find_code_in_mailbox, mail, email)
                if code:
                    logger.debug(f"[{email}] 成功获取验证码: {code}")
                    return code
                await asyncio.sleep(1)

            logger.error(f"[{email}] 验证码邮件未收到")
            raise EmailError("Verification code not received")
        except Exception as e:
            logger.error(f"[{email}] 获取验证码失败: {str(e)}")
            raise EmailError(f"Failed to get verification code: {str(e)}") from e
        finally:
            # Release the IMAP session on every path, not only on success.
            if mail is not None:
                await loop.run_in_executor(None, self._close_mailbox, mail)

    @classmethod
    def _open_mailbox(cls, auth_string: str):
        """Connect to Outlook IMAP, authenticate via XOAUTH2 and select INBOX (blocking).

        Raises:
            EmailError: INBOX could not be selected; imaplib errors propagate.
        """
        import imaplib

        mail = imaplib.IMAP4_SSL('outlook.live.com')
        try:
            mail.authenticate('XOAUTH2', lambda _challenge: auth_string)
            select_status, _ = mail.select('inbox')
            if select_status != 'OK':
                raise EmailError(f"Failed to select inbox: {select_status}")
        except Exception:
            # Do not leak the socket when authentication or SELECT fails.
            cls._close_mailbox(mail)
            raise
        return mail

    @staticmethod
    def _close_mailbox(mail) -> None:
        """Best-effort CLOSE + LOGOUT of an imaplib connection (blocking); errors are logged."""
        try:
            if mail.state == 'SELECTED':
                mail.close()
        except Exception as e:
            logger.debug(f"关闭邮箱时出错: {e}")
        try:
            mail.logout()
        except Exception as e:
            logger.debug(f"登出邮箱时出错: {e}")

    @staticmethod
    def _decode_header_value(raw_value) -> str:
        """Decode an RFC 2047 header value into text; a missing header yields ''."""
        if raw_value is None:
            return ""
        return str(make_header(decode_header(raw_value)))

    def _find_code_in_mailbox(self, mail, address: str) -> Optional[str]:
        """Scan the three newest Cursor mails once and return a code, or None (blocking).

        ``address`` is only used to label log lines.
        """
        from email import message_from_bytes

        result, data = mail.search(None, '(FROM "no-reply@cursor.sh")')
        if result != "OK" or not data or not data[0]:
            logger.debug(f"[{address}] 未找到来自 cursor 的邮件等待1秒后重试")
            return None

        mail_ids = data[0].split()
        if not mail_ids:
            logger.debug(f"[{address}] 邮件ID列表为空等待1秒后重试")
            return None

        # IDs are bytes such as b'9' and b'10'. Sort numerically: a plain reverse
        # sort ranks b'9' above b'10' and would skip the newest messages.
        for mail_id in sorted(mail_ids, key=int, reverse=True)[:3]:
            result, msg_data = mail.fetch(mail_id, "(RFC822)")
            if result != 'OK':
                logger.warning(f"[{address}] 获取邮件内容失败: {result}")
                continue

            # The raw message is the second item of the first (envelope, bytes)
            # tuple; other list items can be plain bytes such as b')'.
            raw_message = next(
                (part[1] for part in msg_data or []
                 if isinstance(part, tuple) and len(part) >= 2),
                None,
            )
            if not raw_message:
                logger.warning(f"[{address}] 邮件数据格式不正确")
                continue

            email_message = message_from_bytes(raw_message)

            from_addr = self._decode_header_value(email_message['From'])
            if 'no-reply@cursor.sh' not in from_addr:
                logger.debug(f"[{address}] 跳过非 Cursor 邮件,发件人: {from_addr}")
                continue

            subject = self._decode_header_value(email_message['SUBJECT'])
            if not any(verify_subject in subject for verify_subject in self.verification_subjects):
                logger.debug(f"[{address}] 跳过非验证码邮件,主题: {subject}")
                continue

            code = self._extract_code_from_email(email_message)
            if code:
                return code
        return None

    @staticmethod
    def _get_email_body(email_message) -> str:
        """Return the decoded body, preferring text/html over text/plain; '' if none."""
        if email_message.is_multipart():
            leaves = [part for part in email_message.walk() if not part.is_multipart()]
            chosen = next((p for p in leaves if p.get_content_type() == "text/html"), None)
            if chosen is None:
                chosen = next((p for p in leaves if p.get_content_type() == "text/plain"), None)
        else:
            chosen = email_message
        if chosen is None:
            return ""

        payload = chosen.get_payload(decode=True)
        if not payload:
            return ""
        charset = chosen.get_content_charset() or 'utf-8'
        try:
            return payload.decode(charset, errors='ignore')
        except LookupError:
            # Unknown charset label in the header; fall back to UTF-8.
            return payload.decode('utf-8', errors='ignore')

    def _extract_code_from_email(self, email_message) -> Optional[str]:
        """Extract a 6-digit verification code from ``email_message``.

        First looks for a ``<div>`` whose only content is six digits. Failing that, it
        searches the tag-stripped text for a standalone 6-digit run. Returns None
        (and logs) when nothing matches or parsing fails.
        """
        import re

        try:
            body = self._get_email_body(email_message)

            match = re.search(r'<div[^>]*>\s*(\d{6})\s*</div>', body)
            if match:
                code = match.group(1)
                logger.debug(f"从HTML中提取到验证码: {code}")
                return code

            # Fallback: drop tags (attributes and URLs often hold digit runs), and
            # ignore runs preceded by '#', such as CSS colours (#333333) and
            # numeric HTML entities (&#128512;).
            text = re.sub(r'<[^>]+>', ' ', body)
            match = re.search(r'(?<![#\w])\d{6}(?!\w)', text)
            if match:
                code = match.group(0)
                logger.debug(f"从文本中提取到验证码: {code}")
                return code

            logger.warning("未能从邮件中提取到验证码")
            logger.debug("邮件内容预览: " + body[:200])
            return None
        except Exception as e:
            logger.error(f"提取验证码失败: {str(e)}")
            return None