Files
auto_cursor_online/core/database.py
2025-03-31 09:55:54 +08:00

86 lines
3.1 KiB
Python

import asyncio
from contextlib import asynccontextmanager
from typing import Any, List, Optional
import aiosqlite
from loguru import logger
from core.config import Config
class DatabaseManager:
    """Async SQLite access through a small pool of ``aiosqlite`` connections.

    The database path and the pool size are read from
    ``config.database_config``. Call :meth:`initialize` before use and
    :meth:`cleanup` on shutdown.
    """

    def __init__(self, config: Config):
        """Store connection settings; no connections are opened here.

        Args:
            config: Application config; ``config.database_config.path`` is the
                SQLite file and ``config.database_config.pool_size`` the number
                of idle connections kept open.
        """
        self.db_path = config.database_config.path
        self._pool_size = config.database_config.pool_size
        # Idle connections; borrowed with pop() and returned with append().
        self._pool: List[aiosqlite.Connection] = []
        # Guards taking a connection out of the pool.
        self._pool_lock = asyncio.Lock()

    async def initialize(self):
        """Create the schema if missing and fill the connection pool.

        Creates the ``email_accounts`` table (``IF NOT EXISTS``), then opens
        connections until the pool holds ``pool_size`` of them. Topping up,
        rather than blindly appending, keeps a repeated call from growing the
        pool beyond its configured size.
        """
        logger.info("初始化数据库连接池")
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS email_accounts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    email TEXT UNIQUE NOT NULL,
                    password TEXT NOT NULL,
                    client_id TEXT NOT NULL,
                    refresh_token TEXT NOT NULL,
                    in_use BOOLEAN DEFAULT 0,
                    cursor_password TEXT,
                    cursor_cookie TEXT,
                    sold BOOLEAN DEFAULT 0,
                    status TEXT DEFAULT 'pending',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            await db.commit()
        # Fill the pool up to its configured size.
        while len(self._pool) < self._pool_size:
            self._pool.append(await aiosqlite.connect(self.db_path))
        logger.info(f"数据库连接池初始化完成,大小: {self._pool_size}")

    async def cleanup(self):
        """Close every idle pooled connection and empty the pool.

        The pool is detached before closing, so a failure to close one
        connection neither leaves already-closed connections in the pool nor
        prevents the remaining ones from being closed; failures are logged.
        Connections currently borrowed via :meth:`get_connection` are not
        touched here.
        """
        idle = list(self._pool)
        self._pool.clear()
        for conn in idle:
            try:
                await conn.close()
            except Exception:
                logger.exception("Failed to close pooled database connection")

    @asynccontextmanager
    async def get_connection(self):
        """Borrow a connection for the duration of an ``async with`` block.

        A pooled connection is reused when available; otherwise a new one is
        opened. On exit the connection is returned to the pool while the pool
        holds fewer than ``pool_size`` connections, and closed otherwise.

        If the ``async with`` body raises (including cancellation), any
        uncommitted work is rolled back before release. Without this, a failed
        statement would leave an open transaction, and its SQLite locks, on a
        pooled connection. The next borrower's ``commit()`` would then commit
        that partial work. A connection whose rollback fails is closed rather
        than reused.

        Yields:
            An ``aiosqlite.Connection``.
        """
        async with self._pool_lock:
            conn = self._pool.pop() if self._pool else None
        if conn is None:
            # Connect outside the lock so other borrowers are not serialized
            # behind a slow connect.
            conn = await aiosqlite.connect(self.db_path)
        reusable = True
        try:
            yield conn
        except BaseException:
            try:
                await conn.rollback()
            except Exception:
                reusable = False
                logger.exception("Rollback failed; discarding database connection")
            raise
        finally:
            if reusable and len(self._pool) < self._pool_size:
                self._pool.append(conn)
            else:
                await conn.close()

    async def execute(self, query: str, params: tuple = ()) -> Any:
        """Execute one SQL statement and commit it.

        Args:
            query: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            ``cursor.lastrowid`` of the executed statement.
        """
        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            try:
                await conn.commit()
                return cursor.lastrowid
            finally:
                await cursor.close()

    async def fetch_one(self, query: str, params: tuple = ()) -> Optional[tuple]:
        """Run a query and return its first row, or ``None`` if there is none.

        Args:
            query: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.
        """
        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            try:
                return await cursor.fetchone()
            finally:
                await cursor.close()

    async def fetch_all(self, query: str, params: tuple = ()) -> List[tuple]:
        """Run a query and return all result rows (empty if none).

        Args:
            query: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.
        """
        async with self.get_connection() as conn:
            cursor = await conn.execute(query, params)
            try:
                return await cursor.fetchall()
            finally:
                await cursor.close()