CursorPro 后台管理系统 v1.0
功能: - 激活码管理 (Pro/Auto 两种类型) - 账号池管理 - 设备绑定记录 - 使用日志 - 搜索/筛选功能 - 禁用/启用功能 (支持退款参考) - 全局设置 (换号间隔、额度消耗等) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2
backend/app/services/__init__.py
Normal file
2
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from app.services.account_service import AccountService, KeyService, LogService, GlobalSettingsService, BatchService
|
||||
from app.services.auth_service import authenticate_admin, create_access_token, get_current_user
|
||||
547
backend/app/services/account_service.py
Normal file
547
backend/app/services/account_service.py
Normal file
@@ -0,0 +1,547 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, and_, or_
|
||||
from app.models import CursorAccount, ActivationKey, KeyDevice, UsageLog, MembershipType, AccountStatus, KeyStatus, GlobalSettings
|
||||
from app.schemas import AccountCreate, KeyCreate
|
||||
|
||||
|
||||
def generate_key(length: int = 32) -> str:
|
||||
"""生成随机激活码"""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(length))
|
||||
|
||||
|
||||
class AccountService:
|
||||
"""账号管理服务"""
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, account: AccountCreate) -> CursorAccount:
|
||||
"""创建账号"""
|
||||
db_account = CursorAccount(
|
||||
email=account.email,
|
||||
access_token=account.access_token,
|
||||
refresh_token=account.refresh_token,
|
||||
workos_session_token=account.workos_session_token,
|
||||
membership_type=account.membership_type,
|
||||
remark=account.remark
|
||||
)
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
return db_account
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, account_id: int) -> Optional[CursorAccount]:
|
||||
return db.query(CursorAccount).filter(CursorAccount.id == account_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_by_email(db: Session, email: str) -> Optional[CursorAccount]:
|
||||
return db.query(CursorAccount).filter(CursorAccount.email == email).first()
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session, skip: int = 0, limit: int = 100) -> List[CursorAccount]:
|
||||
return db.query(CursorAccount).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_available(db: Session, membership_type: MembershipType = None) -> Optional[CursorAccount]:
|
||||
"""获取一个可用账号"""
|
||||
query = db.query(CursorAccount).filter(CursorAccount.status == AccountStatus.ACTIVE)
|
||||
if membership_type:
|
||||
query = query.filter(CursorAccount.membership_type == membership_type)
|
||||
# 优先选择使用次数少的
|
||||
return query.order_by(CursorAccount.usage_count.asc()).first()
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, account_id: int, **kwargs) -> Optional[CursorAccount]:
|
||||
account = db.query(CursorAccount).filter(CursorAccount.id == account_id).first()
|
||||
if account:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(account, key) and value is not None:
|
||||
setattr(account, key, value)
|
||||
db.commit()
|
||||
db.refresh(account)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, account_id: int) -> bool:
|
||||
account = db.query(CursorAccount).filter(CursorAccount.id == account_id).first()
|
||||
if account:
|
||||
db.delete(account)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def mark_used(db: Session, account: CursorAccount, key_id: int = None):
|
||||
"""标记账号被使用"""
|
||||
account.usage_count += 1
|
||||
account.last_used_at = datetime.now()
|
||||
account.status = AccountStatus.IN_USE
|
||||
if key_id:
|
||||
account.current_key_id = key_id
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def release(db: Session, account: CursorAccount):
|
||||
"""释放账号"""
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.current_key_id = None
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def count(db: Session) -> dict:
|
||||
"""统计账号数量"""
|
||||
total = db.query(CursorAccount).count()
|
||||
active = db.query(CursorAccount).filter(CursorAccount.status == AccountStatus.ACTIVE).count()
|
||||
pro = db.query(CursorAccount).filter(CursorAccount.membership_type == MembershipType.PRO).count()
|
||||
return {"total": total, "active": active, "pro": pro}
|
||||
|
||||
|
||||
class KeyService:
|
||||
"""激活码管理服务"""
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, key_data: KeyCreate) -> List[ActivationKey]:
|
||||
"""创建激活码(支持批量)"""
|
||||
keys = []
|
||||
max_retries = 5 # 最大重试次数
|
||||
|
||||
for _ in range(key_data.count):
|
||||
# 生成唯一的key,如果冲突则重试
|
||||
for retry in range(max_retries):
|
||||
key_str = key_data.key if key_data.key and key_data.count == 1 else generate_key()
|
||||
|
||||
# 检查key是否已存在
|
||||
existing = db.query(ActivationKey).filter(ActivationKey.key == key_str).first()
|
||||
if not existing:
|
||||
break
|
||||
if retry == max_retries - 1:
|
||||
raise ValueError(f"无法生成唯一激活码,请重试")
|
||||
|
||||
db_key = ActivationKey(
|
||||
key=key_str,
|
||||
membership_type=key_data.membership_type,
|
||||
quota=key_data.quota if key_data.membership_type == MembershipType.PRO else 0, # Free不需要额度
|
||||
valid_days=key_data.valid_days,
|
||||
max_devices=key_data.max_devices,
|
||||
remark=key_data.remark
|
||||
)
|
||||
db.add(db_key)
|
||||
keys.append(db_key)
|
||||
db.commit()
|
||||
for k in keys:
|
||||
db.refresh(k)
|
||||
return keys
|
||||
|
||||
@staticmethod
|
||||
def get_by_key(db: Session, key: str) -> Optional[ActivationKey]:
|
||||
return db.query(ActivationKey).filter(ActivationKey.key == key).first()
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, key_id: int) -> Optional[ActivationKey]:
|
||||
return db.query(ActivationKey).filter(ActivationKey.id == key_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session, skip: int = 0, limit: int = 100) -> List[ActivationKey]:
|
||||
return db.query(ActivationKey).order_by(ActivationKey.id.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, key_id: int, **kwargs) -> Optional[ActivationKey]:
|
||||
key = db.query(ActivationKey).filter(ActivationKey.id == key_id).first()
|
||||
if key:
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(key, k) and v is not None:
|
||||
setattr(key, k, v)
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, key_id: int) -> bool:
|
||||
key = db.query(ActivationKey).filter(ActivationKey.id == key_id).first()
|
||||
if key:
|
||||
# 删除关联的设备记录
|
||||
db.query(KeyDevice).filter(KeyDevice.key_id == key_id).delete()
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def activate(db: Session, key: ActivationKey):
|
||||
"""首次激活:设置激活时间和过期时间"""
|
||||
if key.first_activated_at is None:
|
||||
key.first_activated_at = datetime.now()
|
||||
if key.valid_days > 0:
|
||||
key.expire_at = key.first_activated_at + timedelta(days=key.valid_days)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def is_valid(key: ActivationKey, db: Session) -> Tuple[bool, str]:
|
||||
"""检查激活码是否有效"""
|
||||
if key.status != KeyStatus.ACTIVE:
|
||||
return False, "激活码已禁用"
|
||||
|
||||
# 检查是否已过期(只有激活后才检查)
|
||||
if key.first_activated_at and key.expire_at and key.expire_at < datetime.now():
|
||||
return False, "激活码已过期"
|
||||
|
||||
# Pro套餐检查额度
|
||||
if key.membership_type == MembershipType.PRO:
|
||||
quota_cost = GlobalSettingsService.get_int(db, "pro_quota_cost")
|
||||
if key.quota_used + quota_cost > key.quota:
|
||||
return False, f"额度不足,需要{quota_cost},剩余{key.quota - key.quota_used}"
|
||||
|
||||
return True, "有效"
|
||||
|
||||
@staticmethod
|
||||
def can_switch(db: Session, key: ActivationKey) -> Tuple[bool, str]:
|
||||
"""检查是否可以换号
|
||||
- Auto: 检查换号间隔 + 每天最大次数(全局设置)
|
||||
- Pro: 无频率限制(只检查额度,在is_valid中)
|
||||
"""
|
||||
# Pro密钥无频率限制
|
||||
if key.membership_type == MembershipType.PRO:
|
||||
return True, "可以换号"
|
||||
|
||||
# === Auto密钥频率检查 ===
|
||||
now = datetime.now()
|
||||
|
||||
# 1. 检查换号间隔
|
||||
interval_minutes = GlobalSettingsService.get_int(db, "auto_switch_interval_minutes")
|
||||
if key.last_switch_at:
|
||||
minutes_since_last = (now - key.last_switch_at).total_seconds() / 60
|
||||
if minutes_since_last < interval_minutes:
|
||||
wait_minutes = int(interval_minutes - minutes_since_last)
|
||||
return False, f"换号太频繁,请等待{wait_minutes}分钟"
|
||||
|
||||
# 2. 检查今日换号次数
|
||||
max_per_day = GlobalSettingsService.get_int(db, "auto_max_switches_per_day")
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_count = db.query(UsageLog).filter(
|
||||
UsageLog.key_id == key.id,
|
||||
UsageLog.action == "switch",
|
||||
UsageLog.success == True,
|
||||
UsageLog.created_at >= today_start
|
||||
).count()
|
||||
|
||||
if today_count >= max_per_day:
|
||||
return False, f"今日换号次数已达上限({max_per_day}次)"
|
||||
|
||||
return True, "可以换号"
|
||||
|
||||
@staticmethod
|
||||
def check_device(db: Session, key: ActivationKey, device_id: str) -> Tuple[bool, str]:
|
||||
"""检查设备限制"""
|
||||
if not device_id:
|
||||
return True, "无设备ID"
|
||||
|
||||
# 查找现有设备
|
||||
device = db.query(KeyDevice).filter(
|
||||
KeyDevice.key_id == key.id,
|
||||
KeyDevice.device_id == device_id
|
||||
).first()
|
||||
|
||||
if device:
|
||||
# 更新最后活跃时间
|
||||
device.last_active_at = datetime.now()
|
||||
db.commit()
|
||||
return True, "设备已绑定"
|
||||
|
||||
# 检查设备数量
|
||||
device_count = db.query(KeyDevice).filter(KeyDevice.key_id == key.id).count()
|
||||
if device_count >= key.max_devices:
|
||||
return False, f"设备数量已达上限({key.max_devices}个)"
|
||||
|
||||
# 添加新设备
|
||||
new_device = KeyDevice(key_id=key.id, device_id=device_id, last_active_at=datetime.now())
|
||||
db.add(new_device)
|
||||
db.commit()
|
||||
return True, "新设备已绑定"
|
||||
|
||||
@staticmethod
|
||||
def use_switch(db: Session, key: ActivationKey):
|
||||
"""使用一次换号(Pro扣除额度,Free不扣)"""
|
||||
if key.membership_type == MembershipType.PRO:
|
||||
quota_cost = GlobalSettingsService.get_int(db, "pro_quota_cost")
|
||||
key.quota_used += quota_cost
|
||||
# Free不扣额度
|
||||
key.switch_count += 1
|
||||
key.last_switch_at = datetime.now()
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_quota_cost(db: Session, membership_type: MembershipType) -> int:
|
||||
"""获取换号消耗的额度"""
|
||||
if membership_type == MembershipType.PRO:
|
||||
return GlobalSettingsService.get_int(db, "pro_quota_cost")
|
||||
return 0 # Free不消耗额度
|
||||
|
||||
@staticmethod
|
||||
def add_quota(db: Session, key: ActivationKey, add_quota: int):
|
||||
"""叠加额度(只加额度不加时间,仅Pro有效)"""
|
||||
key.quota += add_quota
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def bind_account(db: Session, key: ActivationKey, account: CursorAccount):
|
||||
"""绑定账号"""
|
||||
key.current_account_id = account.id
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def count(db: Session) -> dict:
|
||||
"""统计激活码数量"""
|
||||
total = db.query(ActivationKey).count()
|
||||
active = db.query(ActivationKey).filter(ActivationKey.status == KeyStatus.ACTIVE).count()
|
||||
return {"total": total, "active": active}
|
||||
|
||||
|
||||
class LogService:
|
||||
"""日志服务"""
|
||||
|
||||
@staticmethod
|
||||
def log(db: Session, key_id: int, action: str, account_id: int = None,
|
||||
ip_address: str = None, user_agent: str = None,
|
||||
success: bool = True, message: str = None):
|
||||
log = UsageLog(
|
||||
key_id=key_id,
|
||||
account_id=account_id,
|
||||
action=action,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_today_count(db: Session) -> int:
|
||||
today = datetime.now().date()
|
||||
return db.query(UsageLog).filter(
|
||||
func.date(UsageLog.created_at) == today
|
||||
).count()
|
||||
|
||||
|
||||
class GlobalSettingsService:
|
||||
"""全局设置服务"""
|
||||
|
||||
# 默认设置
|
||||
DEFAULT_SETTINGS = {
|
||||
# Auto密钥设置
|
||||
"auto_switch_interval_minutes": ("20", "Auto换号最小间隔(分钟)"),
|
||||
"auto_max_switches_per_day": ("50", "Auto每天最大换号次数"),
|
||||
# Pro密钥设置
|
||||
"pro_quota_cost": ("50", "Pro每次换号扣除额度"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init_settings(db: Session):
|
||||
"""初始化默认设置"""
|
||||
for key, (value, desc) in GlobalSettingsService.DEFAULT_SETTINGS.items():
|
||||
existing = db.query(GlobalSettings).filter(GlobalSettings.key == key).first()
|
||||
if not existing:
|
||||
setting = GlobalSettings(key=key, value=value, description=desc)
|
||||
db.add(setting)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get(db: Session, key: str) -> Optional[str]:
|
||||
"""获取单个设置"""
|
||||
setting = db.query(GlobalSettings).filter(GlobalSettings.key == key).first()
|
||||
if setting:
|
||||
return setting.value
|
||||
# 返回默认值
|
||||
if key in GlobalSettingsService.DEFAULT_SETTINGS:
|
||||
return GlobalSettingsService.DEFAULT_SETTINGS[key][0]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_int(db: Session, key: str) -> int:
|
||||
"""获取整数设置"""
|
||||
value = GlobalSettingsService.get(db, key)
|
||||
return int(value) if value else 0
|
||||
|
||||
@staticmethod
|
||||
def set(db: Session, key: str, value: str, description: str = None):
|
||||
"""设置单个配置"""
|
||||
setting = db.query(GlobalSettings).filter(GlobalSettings.key == key).first()
|
||||
if setting:
|
||||
setting.value = value
|
||||
if description:
|
||||
setting.description = description
|
||||
else:
|
||||
setting = GlobalSettings(key=key, value=value, description=description)
|
||||
db.add(setting)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session) -> dict:
|
||||
"""获取所有设置"""
|
||||
return {
|
||||
"auto_switch_interval_minutes": GlobalSettingsService.get_int(db, "auto_switch_interval_minutes"),
|
||||
"auto_max_switches_per_day": GlobalSettingsService.get_int(db, "auto_max_switches_per_day"),
|
||||
"pro_quota_cost": GlobalSettingsService.get_int(db, "pro_quota_cost"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update_all(db: Session, **kwargs):
|
||||
"""批量更新设置"""
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
GlobalSettingsService.set(db, key, str(value))
|
||||
|
||||
|
||||
class BatchService:
|
||||
"""批量操作服务"""
|
||||
|
||||
@staticmethod
|
||||
def extend_keys(db: Session, key_ids: List[int], extend_days: int = 0, add_quota: int = 0) -> dict:
|
||||
"""批量延长密钥
|
||||
- Auto密钥:只能延长到期时间
|
||||
- Pro密钥:可以延长到期时间 + 增加额度
|
||||
"""
|
||||
success = 0
|
||||
failed = 0
|
||||
errors = []
|
||||
|
||||
for key_id in key_ids:
|
||||
try:
|
||||
key = db.query(ActivationKey).filter(ActivationKey.id == key_id).first()
|
||||
if not key:
|
||||
failed += 1
|
||||
errors.append(f"ID {key_id}: 密钥不存在")
|
||||
continue
|
||||
|
||||
# 延长到期时间
|
||||
if extend_days > 0:
|
||||
if key.expire_at:
|
||||
# 已激活:在当前到期时间基础上延长
|
||||
key.expire_at = key.expire_at + timedelta(days=extend_days)
|
||||
else:
|
||||
# 未激活:增加有效天数
|
||||
key.valid_days += extend_days
|
||||
|
||||
# 增加额度(仅Pro有效)
|
||||
if add_quota > 0 and key.membership_type == MembershipType.PRO:
|
||||
key.quota += add_quota
|
||||
|
||||
db.commit()
|
||||
success += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
errors.append(f"ID {key_id}: {str(e)}")
|
||||
|
||||
return {"success": success, "failed": failed, "errors": errors[:10]}
|
||||
|
||||
@staticmethod
|
||||
def get_keys_for_compensation(
|
||||
db: Session,
|
||||
membership_type: MembershipType = None,
|
||||
activated_before: datetime = None,
|
||||
not_expired_on: datetime = None,
|
||||
) -> List[ActivationKey]:
|
||||
"""获取符合补偿条件的密钥列表
|
||||
- membership_type: 筛选套餐类型 (pro/free)
|
||||
- activated_before: 在此日期之前激活的 (first_activated_at < activated_before)
|
||||
- not_expired_on: 在此日期时还未过期的 (expire_at > not_expired_on)
|
||||
|
||||
例如:补偿12月4号之前激活、且12月4号还没过期的用户
|
||||
activated_before = 2024-12-05 (12月4号之前,即<12月5号0点)
|
||||
not_expired_on = 2024-12-04 (12月4号还没过期,即expire_at > 12月4号)
|
||||
"""
|
||||
query = db.query(ActivationKey)
|
||||
|
||||
if membership_type:
|
||||
query = query.filter(ActivationKey.membership_type == membership_type)
|
||||
|
||||
# 只选择状态为active的
|
||||
query = query.filter(ActivationKey.status == KeyStatus.ACTIVE)
|
||||
|
||||
# 只选择已激活的(有激活时间的)
|
||||
query = query.filter(ActivationKey.first_activated_at != None)
|
||||
|
||||
if activated_before:
|
||||
# 在指定日期之前激活的
|
||||
query = query.filter(ActivationKey.first_activated_at < activated_before)
|
||||
|
||||
if not_expired_on:
|
||||
# 在指定日期时还未过期的 (expire_at > 指定日期 或 永久卡)
|
||||
query = query.filter(
|
||||
or_(
|
||||
ActivationKey.expire_at == None, # 永久卡
|
||||
ActivationKey.expire_at > not_expired_on # 在那天还没过期
|
||||
)
|
||||
)
|
||||
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def batch_compensate(
|
||||
db: Session,
|
||||
membership_type: MembershipType = None,
|
||||
activated_before: datetime = None,
|
||||
not_expired_on: datetime = None,
|
||||
extend_days: int = 0,
|
||||
add_quota: int = 0
|
||||
) -> dict:
|
||||
"""批量补偿 - 根据条件筛选并补偿
|
||||
|
||||
补偿逻辑:
|
||||
- 如果卡当前还没过期:expire_at += extend_days
|
||||
- 如果卡已过期(但符合补偿条件):expire_at = 今天 + extend_days(恢复使用)
|
||||
|
||||
例如: 补偿12月4号之前激活、12月4号还没过期的Auto密钥,延长1天
|
||||
"""
|
||||
keys = BatchService.get_keys_for_compensation(
|
||||
db,
|
||||
membership_type=membership_type,
|
||||
activated_before=activated_before,
|
||||
not_expired_on=not_expired_on
|
||||
)
|
||||
|
||||
if not keys:
|
||||
return {"success": 0, "failed": 0, "total_matched": 0, "recovered": 0, "errors": ["没有符合条件的密钥"]}
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
recovered = 0 # 恢复使用的数量
|
||||
errors = []
|
||||
now = datetime.now()
|
||||
|
||||
for key in keys:
|
||||
try:
|
||||
# 延长到期时间
|
||||
if extend_days > 0 and key.expire_at:
|
||||
if key.expire_at > now:
|
||||
# 还没过期:在当前过期时间上加天数
|
||||
key.expire_at = key.expire_at + timedelta(days=extend_days)
|
||||
else:
|
||||
# 已过期:恢复使用,设为今天+补偿天数
|
||||
key.expire_at = now + timedelta(days=extend_days)
|
||||
recovered += 1
|
||||
|
||||
# 增加额度(仅Pro有效)
|
||||
if add_quota > 0 and key.membership_type == MembershipType.PRO:
|
||||
key.quota += add_quota
|
||||
|
||||
db.commit()
|
||||
success += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
errors.append(f"ID {key.id}: {str(e)}")
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"total_matched": len(keys),
|
||||
"recovered": recovered,
|
||||
"errors": errors[:10]
|
||||
}
|
||||
|
||||
59
backend/app/services/auth_service.py
Normal file
59
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_token(token: str) -> dict:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证凭据",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def authenticate_admin(username: str, password: str) -> bool:
|
||||
"""验证管理员账号"""
|
||||
return username == settings.ADMIN_USERNAME and password == settings.ADMIN_PASSWORD
|
||||
|
||||
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
|
||||
"""获取当前用户"""
|
||||
token = credentials.credentials
|
||||
payload = verify_token(token)
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证凭据"
|
||||
)
|
||||
return {"username": username}
|
||||
Reference in New Issue
Block a user