Files
auto_cursor/core/database.py
2025-04-01 15:43:27 +08:00

274 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import hashlib
import json
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import aiomysql
from loguru import logger

from core.config import Config

# Redis support is optional: use a conditional import instead of a hard dependency.
REDIS_AVAILABLE = False
try:
    # Try redis.asyncio first (redis-py 4.2.0+)
    import redis.asyncio as redis_asyncio
    REDIS_AVAILABLE = True
    REDIS_TYPE = "redis-py"
except ImportError:
    try:
        # Fall back to the legacy aioredis package
        import aioredis
        REDIS_AVAILABLE = True
        REDIS_TYPE = "aioredis"
    except (ImportError, TypeError):
        REDIS_AVAILABLE = False
        REDIS_TYPE = None
class DatabaseManager:
    """Async MySQL access layer with an optional Redis read-through cache.

    MySQL access goes through an ``aiomysql`` connection pool created by
    :meth:`initialize`. When ``database_config.use_redis`` is set, a
    ``redis_config`` is present and a Redis client library was importable,
    non-empty results of :meth:`fetch_one` / :meth:`fetch_all` are cached in
    Redis under ``db:``-prefixed keys, and every non-read statement executed
    through :meth:`execute` invalidates those keys.
    """

    def __init__(self, config: Config):
        """Store configuration only; nothing is connected until ``initialize``.

        Args:
            config: application config exposing ``database_config`` and,
                optionally, ``redis_config``.
        """
        # Database settings
        self.db_config = config.database_config
        self._pool_size = self.db_config.pool_size
        self._pool = None  # aiomysql pool, created in initialize()
        # NOTE(review): not used by any method of this class; kept in case
        # external code relies on the attribute.
        self._pool_lock = asyncio.Lock()
        # Redis settings
        self.use_redis = self.db_config.use_redis
        self.redis_config = getattr(config, "redis_config", None)
        self.redis = None  # Redis client, or None when caching is disabled

    async def initialize(self):
        """Create the MySQL pool, ensure the schema exists, then connect Redis.

        Raises:
            Exception: whatever ``aiomysql.create_pool`` raises (re-raised after
                logging setup hints), or any error from the schema DDL. Redis
                failures are never raised; they only disable caching.
        """
        logger.info("初始化数据库连接池")
        await self._create_mysql_pool()
        await self._create_tables()
        await self._init_redis()
        logger.info(f"数据库连接池初始化完成,大小: {self._pool_size}")

    async def _create_mysql_pool(self):
        """Create the aiomysql pool; on failure log setup hints and re-raise."""
        try:
            logger.info(f"连接MySQL: {self.db_config.host}:{self.db_config.port}, 用户: {self.db_config.username}, 数据库: {self.db_config.database}")
            self._pool = await aiomysql.create_pool(
                host=self.db_config.host,
                port=self.db_config.port,
                user=self.db_config.username,
                password=self.db_config.password,
                db=self.db_config.database,
                maxsize=self._pool_size,
                autocommit=True,
                charset='utf8mb4',
            )
            logger.info("MySQL连接池创建成功")
        except Exception as e:
            logger.error(f"MySQL连接池创建失败: {str(e)}")
            logger.error("请检查MySQL配置是否正确以及MySQL服务是否已启动")
            logger.info("您可能需要创建MySQL用户和数据库")
            # The real password is deliberately NOT interpolated here: log
            # files must never contain credentials.
            logger.info(f" CREATE USER '{self.db_config.username}'@'localhost' IDENTIFIED BY '<password>';")
            logger.info(f" CREATE DATABASE {self.db_config.database} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
            logger.info(f" GRANT ALL PRIVILEGES ON {self.db_config.database}.* TO '{self.db_config.username}'@'localhost';")
            logger.info(" FLUSH PRIVILEGES;")
            raise

    async def _create_tables(self):
        """Create the ``email_accounts`` table if it does not exist yet."""
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute('''
                    CREATE TABLE IF NOT EXISTS email_accounts (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        email VARCHAR(255) UNIQUE NOT NULL,
                        password VARCHAR(255) NOT NULL,
                        client_id VARCHAR(255) NOT NULL,
                        refresh_token TEXT NOT NULL,
                        in_use BOOLEAN DEFAULT 0,
                        cursor_password VARCHAR(255),
                        cursor_cookie TEXT,
                        cursor_token TEXT,
                        sold BOOLEAN DEFAULT 0,
                        status VARCHAR(20) DEFAULT 'pending',
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                        INDEX idx_status_inuse_sold (status, in_use, sold)
                    )
                ''')

    async def _init_redis(self):
        """Connect to Redis when configured; any failure just disables caching."""
        if not (self.use_redis and self.redis_config):
            return
        if not REDIS_AVAILABLE:
            # Previously this case was silent, which made a missing library
            # indistinguishable from a working cache.
            logger.warning("已启用Redis但未安装redis/aioredis库,Redis缓存将被禁用")
            return
        try:
            # Pick the client implementation detected at import time.
            if REDIS_TYPE == "redis-py":
                logger.info(f"使用redis-py连接Redis: {self.redis_config.host}:{self.redis_config.port}")
                self.redis = redis_asyncio.Redis(
                    host=self.redis_config.host,
                    port=self.redis_config.port,
                    password=self.redis_config.password or None,
                    db=self.redis_config.db,
                    decode_responses=True,
                )
            elif REDIS_TYPE == "aioredis":
                logger.info(f"使用aioredis连接Redis: {self.redis_config.host}:{self.redis_config.port}")
                self.redis = await aioredis.from_url(
                    f"redis://{self.redis_config.host}:{self.redis_config.port}",
                    password=self.redis_config.password or None,
                    db=self.redis_config.db,
                    encoding="utf-8",
                    decode_responses=True,
                )
            # Verify connectivity for BOTH client types, so a bad host or
            # password disables caching now instead of failing on first use.
            if self.redis is not None:
                await self.redis.ping()
            logger.info("Redis连接初始化成功")
        except Exception as e:
            logger.error(f"Redis连接初始化失败: {e}")
            logger.info("Redis缓存将被禁用")
            if self.redis is not None:
                try:
                    await self._close_redis(self.redis)
                except Exception as close_error:
                    # Best effort only: the client is being discarded anyway.
                    logger.debug(f"关闭Redis客户端失败: {close_error}")
            self.redis = None

    @staticmethod
    async def _close_redis(client) -> None:
        """Close a Redis client, preferring ``aclose`` (redis-py >= 5) over ``close``."""
        close = getattr(client, "aclose", None) or client.close
        await close()

    async def cleanup(self):
        """Close the MySQL pool and the Redis client, if they were created.

        Both attributes are reset to None first, so calling ``cleanup`` twice is
        harmless. Later ``get_connection`` calls fail with the "not initialized"
        error. Redis is closed even if shutting down the pool raises.
        """
        pool, self._pool = self._pool, None
        redis_client, self.redis = self.redis, None
        try:
            if pool is not None:
                pool.close()
                await pool.wait_closed()
        finally:
            if redis_client is not None:
                await self._close_redis(redis_client)
        logger.info("数据库连接已清理")

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled MySQL connection; it is released back to the pool on exit.

        Raises:
            RuntimeError: if ``initialize`` has not created the pool (a subclass
                of ``Exception``, so existing ``except Exception`` callers still
                catch it).
        """
        if self._pool is None:
            raise RuntimeError("数据库连接池未初始化")
        async with self._pool.acquire() as conn:
            yield conn

    @staticmethod
    def _is_read_query(query: str) -> bool:
        """Return True for statements that cannot modify data (no cache invalidation needed)."""
        return query.lstrip().upper().startswith(("SELECT", "SHOW", "DESCRIBE", "EXPLAIN"))

    async def execute(self, query: str, params: tuple = ()) -> Any:
        """Execute a (typically write) SQL statement.

        Returns:
            ``cursor.lastrowid`` for INSERT statements, otherwise
            ``cursor.rowcount``.

        Raises:
            Exception: any driver/pool error, re-raised after logging.
        """
        # NOTE(review): params may contain passwords/tokens and end up in the
        # logs below; consider redacting.
        logger.debug(f"执行SQL: {query}, 参数: {params}")
        try:
            async with self.get_connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(query, params)
                    if query.strip().upper().startswith("INSERT"):
                        # INSERT: report the auto-increment id of the new row
                        result = cursor.lastrowid
                    else:
                        # UPDATE/DELETE/...: report the number of affected rows
                        result = cursor.rowcount
        except Exception as e:
            logger.error(f"SQL执行失败: {query}, 参数: {params}, 错误: {str(e)}")
            raise
        # Cached SELECT results may now be stale (e.g. an account whose in_use
        # flag was just flipped), so drop them. clear_cache never raises.
        if self.redis and not self._is_read_query(query):
            await self.clear_cache()
        return result

    async def fetch_one(self, query: str, params: tuple = ()) -> Optional[Dict]:
        """Return the first row of ``query`` as a dict, or None when there is none.

        A Redis hit is returned without touching MySQL. A non-empty,
        JSON-serializable result is cached for the default TTL.
        """
        logger.debug(f"查询单条: {query}, 参数: {params}")
        # Try the Redis cache first
        cache_key = f"db:{self._make_cache_key(query, params)}"
        cached_result = await self._get_from_cache(cache_key)
        if cached_result is not None:
            return cached_result
        try:
            async with self.get_connection() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(query, params)
                    result = await cursor.fetchone()
        except Exception as e:
            logger.error(f"查询单条失败: {query}, 参数: {params}, 错误: {str(e)}")
            raise
        # Cache only after the connection is back in the pool, so Redis latency
        # never holds a MySQL connection.
        if result and self.redis:
            await self._store_in_cache(cache_key, result)
        logger.debug(f"查询结果: {result}")
        return result

    async def fetch_all(self, query: str, params: tuple = ()) -> List[Dict]:
        """Return all rows of ``query`` as a list of dicts.

        A Redis hit is returned without touching MySQL. A non-empty,
        JSON-serializable result is cached for the default TTL.
        """
        logger.debug(f"查询多条: {query}, 参数: {params}")
        # Try the Redis cache first
        cache_key = f"db:{self._make_cache_key(query, params)}"
        cached_result = await self._get_from_cache(cache_key)
        if cached_result is not None:
            return cached_result
        try:
            async with self.get_connection() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(query, params)
                    # Normalize to a list so fresh and cached results share a type.
                    results = list(await cursor.fetchall())
        except Exception as e:
            logger.error(f"查询多条失败: {query}, 参数: {params}, 错误: {str(e)}")
            raise
        if results and self.redis:
            await self._store_in_cache(cache_key, results)
        logger.debug(f"查询结果数量: {len(results)}")
        return results

    async def _get_from_cache(self, key: str) -> Optional[Union[Dict, List[Dict]]]:
        """Return the JSON-decoded value at ``key``, or None on miss/error/no Redis."""
        if not self.redis:
            return None
        try:
            cached_data = await self.redis.get(key)
            if cached_data:
                return json.loads(cached_data)
        except Exception as e:
            logger.error(f"从缓存获取数据失败: {e}")
        return None

    async def _store_in_cache(self, key: str, data: Union[Dict, List[Dict]], ttl: int = 300) -> bool:
        """Store ``data`` as JSON at ``key`` with a TTL in seconds; return True on success."""
        if not self.redis:
            return False
        try:
            json_data = json.dumps(data)
        except (TypeError, ValueError) as e:
            # Rows holding datetime/Decimal/bytes values are not JSON
            # serializable. They are deliberately left uncached rather than
            # stringified, so callers always receive the driver's native types.
            logger.debug(f"数据无法序列化为JSON,跳过缓存: {e}")
            return False
        try:
            await self.redis.setex(key, ttl, json_data)
            return True
        except Exception as e:
            logger.error(f"存储数据到缓存失败: {e}")
            return False

    async def clear_cache(self, pattern: str = "db:*") -> int:
        """Delete cache keys matching ``pattern``; return how many were removed.

        Uses incremental SCAN instead of KEYS, so a large keyspace never blocks
        the Redis server, and deletes in bounded batches. Errors are logged and
        reported as 0.
        """
        if not self.redis:
            return 0
        batch_size = 500
        try:
            deleted = 0
            batch = []
            async for key in self.redis.scan_iter(match=pattern, count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += await self.redis.delete(*batch)
                    batch.clear()
            if batch:
                deleted += await self.redis.delete(*batch)
            return deleted
        except Exception as e:
            logger.error(f"清除缓存失败: {e}")
            return 0

    def _make_cache_key(self, query: str, params: tuple) -> str:
        """Build a process-independent cache key for ``query`` + ``params``.

        The query text is kept verbatim as the key prefix, so pattern-based
        ``clear_cache`` calls that match on it keep working. The params are
        reduced to a SHA-256 digest of their JSON form. Unlike the built-in
        ``hash()``, which is salted per interpreter run for strings, the digest
        is identical in every process sharing the Redis cache. It also accepts
        unhashable params such as lists and dicts.
        """
        serialized = json.dumps(params, default=repr, ensure_ascii=False)
        digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
        return f"{query}:{digest}"