Files
auto_cursor/services/email_manager.py
huangzhenpc 31fe73f998 sdsad
2025-04-01 18:38:16 +08:00

283 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import email
from dataclasses import dataclass
from email.header import decode_header, make_header
from typing import Dict, List, Optional
import aiohttp
from loguru import logger
from core.config import Config
from core.database import DatabaseManager
from core.exceptions import EmailError
@dataclass
class EmailAccount:
    """In-memory record for one mailbox account handled by the registration flow.

    The first five fields are required and positional; the remaining fields
    carry defaults describing an account that has not been processed yet.
    """

    id: int
    email: str
    # Per the original author's note, this field actually holds a refresh token.
    password: str
    client_id: str
    refresh_token: str
    in_use: bool = False
    cursor_password: Optional[str] = None
    cursor_cookie: Optional[str] = None
    sold: bool = False
    # Status values listed by the author: pending, unavailable, success.
    status: str = 'pending'
class EmailManager:
    """Hand out mailbox accounts from ``email_accounts`` and read verification codes.

    Database access goes through the injected ``db_manager``; verification
    mails are read over IMAP using a Microsoft OAuth2 access token.
    """

    def __init__(self, config: Config, db_manager: DatabaseManager):
        """Store the collaborators used by the other methods.

        Args:
            config: Application configuration object (stored, not read here).
            db_manager: Database handle providing ``fetch_all``, ``fetch_one``
                and ``execute`` coroutines.
        """
        self.config = config
        self.db = db_manager
        # A mail counts as a verification mail only when its subject contains
        # one of these fragments.
        self.verification_subjects = [
            "Verify your email address",
            "Complete code challenge",
        ]

    async def batch_get_accounts(self, num: int) -> List[EmailAccount]:
        """Fetch up to ``num`` unused pending accounts and flag them as in use.

        Args:
            num: Maximum number of accounts to return.

        Returns:
            The selected accounts with ``in_use=True``; an empty list when no
            row matches.
        """
        logger.info(f"尝试获取 {num} 个未使用的邮箱账号")

        # Pending, unsold, not-in-use rows whose address has no other row
        # already marked success/unavailable.
        select_query = """
            SELECT id, email, password, client_id, refresh_token
            FROM email_accounts
            WHERE in_use = 0 AND sold = 0 AND status = 'pending'
            AND email NOT IN (
                SELECT email FROM email_accounts
                WHERE status = 'success' OR status = 'unavailable'
            )
            LIMIT %s
        """
        accounts = await self.db.fetch_all(select_query, (num,))
        if not accounts:
            logger.debug("没有找到符合条件的账号")
            return []

        account_ids = [account['id'] for account in accounts]

        # Flag the rows so later queries skip them.
        # NOTE(review): the SELECT above and this UPDATE are separate
        # statements, so two concurrent callers could pick the same rows
        # unless the database layer serialises them - confirm.
        placeholders = ', '.join(['%s'] * len(account_ids))
        update_query = f"""
            UPDATE email_accounts
            SET in_use = 1, updated_at = CURRENT_TIMESTAMP
            WHERE id IN ({placeholders})
        """
        await self.db.execute(update_query, tuple(account_ids))
        logger.info(f"已锁定 {len(account_ids)} 个账号用于注册: {account_ids}")

        logger.debug(f"实际获取到 {len(accounts)} 个账号")
        return [
            EmailAccount(
                id=row['id'],
                email=row['email'],
                password=row['password'],
                client_id=row['client_id'],
                refresh_token=row['refresh_token'],
                in_use=True,
            )
            for row in accounts
        ]

    async def update_account_status(self, account_id: int, status: str):
        """Set an account's status and clear its in-use flag.

        Args:
            account_id: Primary key of the row to update.
            status: New value for the ``status`` column.
        """
        query = '''
            UPDATE email_accounts
            SET
                status = %s,
                in_use = 0,
                updated_at = CURRENT_TIMESTAMP
            WHERE id = %s
        '''
        await self.db.execute(query, (status, account_id))

    async def update_account(self, account_id: int, cursor_password: str, cursor_cookie: str, cursor_token: str):
        """Store the Cursor credentials and mark the account as finished.

        Besides the three credential columns, the row is set to
        ``in_use = 0``, ``sold = 1`` and ``status = 'success'``.

        Args:
            account_id: Primary key of the row to update.
            cursor_password: Value for the ``cursor_password`` column.
            cursor_cookie: Value for the ``cursor_cookie`` column.
            cursor_token: Value for the ``cursor_token`` column.
        """
        query = '''
            UPDATE email_accounts
            SET
                cursor_password = %s,
                cursor_cookie = %s,
                cursor_token = %s,
                in_use = 0,
                sold = 1,
                status = 'success',
                updated_at = CURRENT_TIMESTAMP
            WHERE id = %s
        '''
        await self.db.execute(query, (cursor_password, cursor_cookie, cursor_token, account_id))

    async def release_account(self, account_id: int):
        """Clear the in-use flag of an account without touching its status.

        Args:
            account_id: Primary key of the row to release.
        """
        query = '''
            UPDATE email_accounts
            SET in_use = 0, updated_at = CURRENT_TIMESTAMP
            WHERE id = %s
        '''
        await self.db.execute(query, (account_id,))

    async def count_pending_accounts(self) -> int:
        """Count accounts that ``batch_get_accounts`` would currently consider.

        Returns:
            The row count, or 0 when the query yields no result row.
        """
        query = """
            SELECT COUNT(*)
            FROM email_accounts
            WHERE status = 'pending' AND in_use = 0 AND sold = 0
            AND email NOT IN (
                SELECT email FROM email_accounts
                WHERE status = 'success' OR status = 'unavailable'
            )
        """
        result = await self.db.fetch_one(query)
        if result:
            # The row is read as a mapping keyed by the literal column expression.
            return result.get("COUNT(*)", 0)
        return 0

    async def _get_access_token(self, client_id: str, refresh_token: str) -> str:
        """Exchange a refresh token for a Microsoft access token.

        Args:
            client_id: OAuth client id sent with the request.
            refresh_token: Refresh token to exchange.

        Returns:
            The ``access_token`` value from the token endpoint response.

        Raises:
            EmailError: If the response contains ``error`` or has no
                ``access_token``.
        """
        logger.debug(f"开始获取 access token - client_id: {client_id}")
        url = 'https://login.microsoftonline.com/common/oauth2/v2.0/token'
        data = {
            'client_id': client_id,
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data) as response:
                result = await response.json()

        if 'error' in result:
            error = result.get('error')
            logger.error(f"获取 access token 失败: {error}")
            raise EmailError(f"Failed to get access token: {error}")

        # Guard against a response carrying neither an error nor a token, so
        # callers see EmailError instead of a bare KeyError.
        access_token = result.get('access_token')
        if not access_token:
            logger.error("获取 access token 失败: 响应中缺少 access_token")
            raise EmailError("Failed to get access token: response missing access_token")

        logger.debug("成功获取 access token")
        return access_token

    async def get_verification_code(self, email: str, refresh_token: str, client_id: str) -> str:
        """Poll the mailbox for a Cursor verification mail and return its code.

        The inbox is searched up to 15 times, one second apart, for mail from
        ``no-reply@cursor.sh``; the three newest matches are inspected on each
        attempt.

        Args:
            email: Mailbox address, used as the IMAP user name.
            refresh_token: Refresh token exchanged for an access token.
            client_id: OAuth client id used for the token exchange.

        Returns:
            The extracted verification code.

        Raises:
            EmailError: On any failure, including no code arriving in time.
        """
        # Local imports: ``imaplib`` is not imported at module level, and the
        # ``email`` parameter hides the ``email`` module inside this function.
        import imaplib
        from email import message_from_bytes

        logger.info(f"开始获取邮箱验证码 - {email}")
        mail = None
        try:
            access_token = await self._get_access_token(client_id, refresh_token)
            logger.debug(f"[{email}] 获取 access token 成功")

            # XOAUTH2 initial response: fields separated by the \x01 control character.
            auth_string = f"user={email}\x01auth=Bearer {access_token}\x01\x01"
            logger.debug(f"[{email}] 认证字符串构建完成")

            # NOTE(review): imaplib is synchronous, so these calls block the
            # event loop while they run - consider moving them to a thread.
            mail = imaplib.IMAP4_SSL('outlook.live.com')
            mail.authenticate('XOAUTH2', lambda x: auth_string)
            mail.select('inbox')
            logger.debug(f"[{email}] 邮箱连接成功")

            for i in range(15):
                logger.debug(f"[{email}] 第 {i + 1} 次尝试获取验证码")

                result, data = mail.search(None, '(FROM "no-reply@cursor.sh")')
                if result != "OK" or not data[0]:
                    logger.debug(f"[{email}] 未找到来自 cursor 的邮件等待1秒后重试")
                    await asyncio.sleep(1)
                    continue

                mail_ids = data[0].split()
                if not mail_ids:
                    logger.debug(f"[{email}] 邮件ID列表为空等待1秒后重试")
                    await asyncio.sleep(1)
                    continue

                # IDs arrive as byte strings; compare them numerically so that
                # b'10' ranks above b'9' when picking the three newest mails.
                last_mail_ids = sorted(mail_ids, key=int, reverse=True)[:3]
                for mail_id in last_mail_ids:
                    result, msg_data = mail.fetch(mail_id, "(RFC822)")
                    if result != 'OK':
                        logger.warning(f"[{email}] 获取邮件内容失败: {result}")
                        continue

                    # Expect msg_data[0] to be an (envelope, raw_bytes) pair.
                    if not msg_data or not msg_data[0] or len(msg_data[0]) < 2:
                        logger.warning(f"[{email}] 邮件数据格式不正确")
                        continue

                    email_message = message_from_bytes(msg_data[0][1])

                    from_addr = self._header_text(email_message, 'From')
                    if 'no-reply@cursor.sh' not in from_addr:
                        logger.debug(f"[{email}] 跳过非 Cursor 邮件,发件人: {from_addr}")
                        continue

                    subject = self._header_text(email_message, 'SUBJECT')
                    if not any(verify_subject in subject for verify_subject in self.verification_subjects):
                        logger.debug(f"[{email}] 跳过非验证码邮件,主题: {subject}")
                        continue

                    code = self._extract_code_from_email(email_message)
                    if code:
                        logger.debug(f"[{email}] 成功获取验证码: {code}")
                        return code

                await asyncio.sleep(1)

            logger.error(f"[{email}] 验证码邮件未收到")
            # This is re-wrapped by the handler below, as in the original flow.
            raise EmailError("Verification code not received")
        except Exception as e:
            logger.error(f"[{email}] 获取验证码失败: {str(e)}")
            raise EmailError(f"Failed to get verification code: {str(e)}") from e
        finally:
            # Release the IMAP connection on every exit path, not only on success.
            if mail is not None:
                self._close_mailbox(mail)

    @staticmethod
    def _close_mailbox(mail) -> None:
        """Close and log out of an IMAP connection, ignoring cleanup errors."""
        import imaplib

        # ``close`` fails when no mailbox is selected; still attempt ``logout``.
        for step in (mail.close, mail.logout):
            try:
                step()
            except (imaplib.IMAP4.error, OSError) as exc:
                logger.debug(f"关闭邮箱连接时出错: {exc}")

    @staticmethod
    def _header_text(email_message, name: str) -> str:
        """Return the decoded header ``name``, or ``''`` when it is absent."""
        raw_value = email_message[name]
        if raw_value is None:
            return ''
        return str(make_header(decode_header(raw_value)))

    @staticmethod
    def _message_body(email_message) -> str:
        """Return the mail text, preferring ``text/html`` over ``text/plain``.

        Payloads are decoded as UTF-8 with undecodable bytes dropped. An empty
        string is returned when no usable text part exists.
        """
        if not email_message.is_multipart():
            payload = email_message.get_payload(decode=True)
            return payload.decode('utf-8', errors='ignore') if payload else ''

        plain_fallback = ''
        for part in email_message.walk():
            content_type = part.get_content_type()
            if content_type not in ('text/html', 'text/plain'):
                continue
            payload = part.get_payload(decode=True)
            if not payload:
                continue
            text = payload.decode('utf-8', errors='ignore')
            if content_type == 'text/html':
                return text
            if not plain_fallback:
                plain_fallback = text
        return plain_fallback

    def _extract_code_from_email(self, email_message) -> Optional[str]:
        """Extract a 6-digit verification code from a parsed mail.

        Args:
            email_message: Parsed ``email.message.Message`` object.

        Returns:
            The code as a string, or ``None`` when none is found or parsing fails.
        """
        import re

        try:
            body = self._message_body(email_message)

            # Preferred form: a <div> whose entire content is six digits.
            match = re.search(r'<div[^>]*>(\d{6})</div>', body)
            if match:
                code = match.group(1)
                logger.debug(f"从HTML中提取到验证码: {code}")
                return code

            # Fallback: any standalone run of six digits.
            match = re.search(r'\b\d{6}\b', body)
            if match:
                code = match.group(0)
                logger.debug(f"从文本中提取到验证码: {code}")
                return code

            logger.warning("未能从邮件中提取到验证码")
            logger.debug("邮件内容预览: " + body[:200])
            return None
        except Exception as e:
            # Best effort: a malformed mail must not abort the polling loop.
            logger.error(f"提取验证码失败: {str(e)}")
            return None