Files
cursornew2026/backend/app/services/account_service.py
huangzhenpc ac19d029da backend v2.1: 公告管理功能 + 系统重构
- 新增 Announcement 数据模型,支持公告的增删改查
- 后台管理新增"公告管理"Tab(创建/编辑/删除/启用禁用)
- 客户端 /api/announcement 改为从数据库读取
- 账号服务重构,新增无感换号、自动分析等功能
- 新增后台任务调度器、数据库迁移脚本
- Schema/Service/Config 全面升级至 v2.1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 19:58:05 +08:00

925 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
蜂鸟Pro 账号服务 v2.1
基于系统设计文档重构
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func, and_, or_
from decimal import Decimal
from app.models.models import (
CursorAccount, ActivationKey, KeyDevice, UsageLog, GlobalSettings,
AccountStatus, AccountType, KeyMembershipType, KeyStatus
)
def generate_key(length: int = 32) -> str:
    """Return a random activation code made of ``length`` characters.

    Each character is drawn independently with ``secrets.choice`` from the
    uppercase ASCII letters and the decimal digits.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
# ==================== 账号服务 ====================
class AccountService:
    """Management service for pooled Cursor accounts (``CursorAccount`` rows).

    Covers creation, lookup, allocation to activation keys, lock/release and
    updating accounts from Cursor API analysis results. Every mutating method
    commits the session itself.
    """

    @staticmethod
    def create(
        db: Session,
        email: str,
        token: Optional[str] = None,
        password: Optional[str] = None,
        remark: Optional[str] = None,
        access_token: Optional[str] = None,
        refresh_token: Optional[str] = None,
        workos_session_token: Optional[str] = None
    ) -> CursorAccount:
        """Create an account in PENDING status so background analysis can classify it.

        The primary ``token`` column receives the first non-empty value among
        ``workos_session_token``, ``token`` and ``access_token``;
        ``workos_session_token`` is stored as ``workos_session_token or token``.

        Raises:
            ValueError: if none of ``workos_session_token``, ``token`` or
                ``access_token`` is provided.
        """
        resolved_token = workos_session_token or token or access_token
        if not resolved_token:
            raise ValueError("需要至少提供一个 Token (Workos 或 Access)")
        db_account = CursorAccount(
            email=email,
            token=resolved_token,
            access_token=access_token,
            refresh_token=refresh_token,
            workos_session_token=workos_session_token or token,
            password=password,
            status=AccountStatus.PENDING,
            remark=remark
        )
        db.add(db_account)
        db.commit()
        db.refresh(db_account)
        return db_account

    @staticmethod
    def get_by_id(db: Session, account_id: int) -> Optional[CursorAccount]:
        """Return the account with primary key ``account_id``, or None."""
        return db.query(CursorAccount).filter(CursorAccount.id == account_id).first()

    @staticmethod
    def get_by_email(db: Session, email: str) -> Optional[CursorAccount]:
        """Return the first account whose email equals ``email``, or None."""
        return db.query(CursorAccount).filter(CursorAccount.email == email).first()

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        status: Optional[AccountStatus] = None
    ) -> List[CursorAccount]:
        """List accounts newest-first (by id), optionally filtered by ``status``, with offset/limit paging."""
        query = db.query(CursorAccount)
        if status:
            query = query.filter(CursorAccount.status == status)
        return query.order_by(CursorAccount.id.desc()).offset(skip).limit(limit).all()

    @staticmethod
    def get_available_for_key(db: Session, key: ActivationKey) -> Optional[CursorAccount]:
        """Pick an AVAILABLE account suited to ``key``'s membership type.

        AUTO keys draw only from FREE_TRIAL / FREE accounts; every other key
        type draws from PRO / FREE_TRIAL / BUSINESS accounts. Among the
        candidates, the one with the lowest ``usage_percent`` is returned.
        The account is not locked here; see ``lock_account``.
        """
        # NOTE(review): no row lock is taken, so two concurrent callers may pick
        # the same account before either locks it — confirm whether that matters.
        query = db.query(CursorAccount).filter(
            CursorAccount.status == AccountStatus.AVAILABLE
        )
        if key.membership_type == KeyMembershipType.AUTO:
            query = query.filter(
                CursorAccount.account_type.in_([AccountType.FREE_TRIAL, AccountType.FREE])
            )
        else:
            query = query.filter(
                CursorAccount.account_type.in_([AccountType.PRO, AccountType.FREE_TRIAL, AccountType.BUSINESS])
            )
        # Lowest usage first.
        return query.order_by(CursorAccount.usage_percent.asc()).first()

    @staticmethod
    def get_pending_accounts(db: Session, limit: int = 10) -> List[CursorAccount]:
        """Return up to ``limit`` PENDING/AVAILABLE accounts due for (re)analysis.

        An account is due when it has never been analyzed or was last analyzed
        more than 30 minutes ago.
        """
        return db.query(CursorAccount).filter(
            CursorAccount.status.in_([AccountStatus.PENDING, AccountStatus.AVAILABLE]),
            or_(
                CursorAccount.last_analyzed_at == None,  # noqa: E711 - SQLAlchemy renders IS NULL
                CursorAccount.last_analyzed_at < datetime.now() - timedelta(minutes=30)
            )
        ).limit(limit).all()

    @staticmethod
    def update_from_analysis(db: Session, account_id: int, analysis_data: Dict[str, Any]) -> Optional[CursorAccount]:
        """Apply a Cursor API analysis result to the account and commit.

        On ``analysis_data["success"]`` the status, account type, billing and
        usage fields are copied from the dict and ``analyze_error`` is cleared.
        Otherwise the account is marked INVALID and the error text is stored.

        Returns:
            The updated account, or None if ``account_id`` does not exist.

        Raises:
            KeyError: if a successful result lacks ``status``/``account_type``
                or names a value that is not a member of the enum.
        """
        account = AccountService.get_by_id(db, account_id)
        if not account:
            return None

        if analysis_data.get("success"):
            new_status = AccountStatus[analysis_data["status"].upper()]
            # An account currently locked by a key must stay IN_USE; flipping it to
            # AVAILABLE would let get_available_for_key() hand it to a second key.
            if account.status == AccountStatus.IN_USE and new_status == AccountStatus.AVAILABLE:
                new_status = AccountStatus.IN_USE
            account.status = new_status
            account.account_type = AccountType[analysis_data["account_type"].upper()]
            account.membership_type = analysis_data.get("membership_type")
            account.billing_cycle_start = analysis_data.get("billing_cycle_start")
            account.billing_cycle_end = analysis_data.get("billing_cycle_end")
            account.trial_days_remaining = analysis_data.get("trial_days_remaining", 0)
            account.usage_limit = analysis_data.get("usage_limit", 0)
            account.usage_used = analysis_data.get("usage_used", 0)
            account.usage_remaining = analysis_data.get("usage_remaining", 0)
            account.usage_percent = analysis_data.get("usage_percent", Decimal("0"))
            account.total_requests = analysis_data.get("total_requests", 0)
            account.total_input_tokens = analysis_data.get("total_input_tokens", 0)
            account.total_output_tokens = analysis_data.get("total_output_tokens", 0)
            account.total_cost_cents = analysis_data.get("total_cost_cents", Decimal("0"))
            # Guard against an explicit None, which would make the account look
            # never-analyzed to get_pending_accounts() forever.
            account.last_analyzed_at = analysis_data.get("last_analyzed_at") or datetime.now()
            account.analyze_error = None
        else:
            account.status = AccountStatus.INVALID
            account.analyze_error = analysis_data.get("error", "Unknown error")
            account.last_analyzed_at = datetime.now()

        db.commit()
        db.refresh(account)
        return account

    @staticmethod
    def lock_account(db: Session, account: CursorAccount, key_id: int) -> CursorAccount:
        """Mark ``account`` IN_USE by key ``key_id``, record the lock time and commit."""
        account.status = AccountStatus.IN_USE
        account.locked_by_key_id = key_id
        account.locked_at = datetime.now()
        db.commit()
        db.refresh(account)
        return account

    @staticmethod
    def release_account(db: Session, account: CursorAccount) -> CursorAccount:
        """Clear the key lock on ``account`` and return it to the pool, then commit.

        The new status is EXHAUSTED when ``usage_percent`` is at or above the
        ``auto_switch_threshold`` setting (98 when unset/0), otherwise
        AVAILABLE. INVALID accounts keep their status so a broken account is not
        handed out again just because a key released it.
        """
        if account.status != AccountStatus.INVALID:
            threshold = GlobalSettingsService.get_int(db, "auto_switch_threshold") or 98
            if account.usage_percent and float(account.usage_percent) >= threshold:
                account.status = AccountStatus.EXHAUSTED
            else:
                account.status = AccountStatus.AVAILABLE
        account.locked_by_key_id = None
        account.locked_at = None
        db.commit()
        db.refresh(account)
        return account

    @staticmethod
    def update(db: Session, account_id: int, **kwargs) -> Optional[CursorAccount]:
        """Set each existing attribute named in ``kwargs`` whose value is not None; commit.

        None values are skipped, so this cannot be used to clear a field.
        Returns the account, or None if it does not exist.
        """
        account = AccountService.get_by_id(db, account_id)
        if account:
            for field, value in kwargs.items():
                if hasattr(account, field) and value is not None:
                    setattr(account, field, value)
            db.commit()
            db.refresh(account)
        return account

    @staticmethod
    def delete(db: Session, account_id: int) -> bool:
        """Delete the account and commit; return False if it does not exist."""
        account = AccountService.get_by_id(db, account_id)
        if not account:
            return False
        db.delete(account)
        db.commit()
        return True

    @staticmethod
    def count(db: Session) -> dict:
        """Return the total number of accounts and the count per status."""
        rows = db.query(
            CursorAccount.status, func.count(CursorAccount.id)
        ).group_by(CursorAccount.status).all()
        counts = {status.value: cnt for status, cnt in rows}
        return {
            "total": sum(counts.values()),
            "available": counts.get(AccountStatus.AVAILABLE.value, 0),
            "in_use": counts.get(AccountStatus.IN_USE.value, 0),
            "exhausted": counts.get(AccountStatus.EXHAUSTED.value, 0),
            "pending": counts.get(AccountStatus.PENDING.value, 0),
            "invalid": counts.get(AccountStatus.INVALID.value, 0),
        }
# ==================== 激活码服务 ====================
class KeyService:
    """Activation-key (``ActivationKey``) service.

    Handles batch creation, activation (including merging a newly used key
    into an existing master key bound to the same device), per-key device
    binding, enabling/disabling seamless account switching, and switching.
    """

    @staticmethod
    def _get_bool_setting(db: Session, key: str, default: bool = False) -> bool:
        """Read a global setting as a boolean.

        Returns ``default`` when the setting is missing; otherwise True for
        "true", "1", "yes", "y" or "on" (case-insensitive, surrounding
        whitespace ignored) and False for anything else.
        """
        value = GlobalSettingsService.get(db, key)
        if value is None:
            return default
        return str(value).strip().lower() in ("true", "1", "yes", "y", "on")

    @staticmethod
    def _get_key_max_devices(db: Session) -> int:
        """Return the global ``key_max_devices`` setting, or 2 when it is unset or not positive."""
        # NOTE(review): ActivationKey.max_devices (written by create()) is not
        # consulted here; the global setting applies to every key. Confirm intended.
        limit = GlobalSettingsService.get_int(db, "key_max_devices")
        return limit if limit > 0 else 2

    @staticmethod
    def _is_device_bound(db: Session, key_id: int, device_id: str) -> bool:
        """Return True if a KeyDevice row binds ``device_id`` to ``key_id`` (False for an empty id)."""
        if not device_id:
            return False
        return db.query(KeyDevice).filter(
            KeyDevice.key_id == key_id,
            KeyDevice.device_id == device_id
        ).first() is not None

    @staticmethod
    def _ensure_device_bound(db: Session, key: ActivationKey, device_id: str) -> Tuple[bool, str]:
        """Make sure ``device_id`` is bound to ``key``, respecting the device limit.

        An already-bound device only gets ``last_active_at`` refreshed; a new
        device is added when the key is below the limit. Changes are staged in
        the session only — the caller must commit.

        Returns:
            ``(True, "")`` on success (also for an empty ``device_id``), or
            ``(False, message)`` when the device limit has been reached.
        """
        if not device_id:
            return True, ""
        existing = db.query(KeyDevice).filter(
            KeyDevice.key_id == key.id,
            KeyDevice.device_id == device_id
        ).first()
        now = datetime.now()
        if existing:
            existing.last_active_at = now
            return True, ""
        limit = KeyService._get_key_max_devices(db)
        bound_count = db.query(KeyDevice).filter(KeyDevice.key_id == key.id).count()
        if bound_count >= limit:
            return False, f"设备数量已达上限({limit}),请先解绑旧设备"
        db.add(KeyDevice(
            key_id=key.id,
            device_id=device_id,
            last_active_at=now
        ))
        return True, ""

    @staticmethod
    def _find_master_for_device(db: Session, membership_type: KeyMembershipType, device_id: str) -> Optional[ActivationKey]:
        """Find the newest non-disabled master key of ``membership_type`` used by ``device_id``.

        A master key is one with no ``master_key_id``. The ``key_devices``
        table is consulted first; the legacy ``activation_keys.device_id``
        column is used as a fallback for older data. Returns None for an empty
        ``device_id`` or when nothing matches.
        """
        if not device_id:
            return None
        master = db.query(ActivationKey).join(
            KeyDevice, KeyDevice.key_id == ActivationKey.id
        ).filter(
            KeyDevice.device_id == device_id,
            ActivationKey.membership_type == membership_type,
            ActivationKey.status != KeyStatus.DISABLED,
            ActivationKey.master_key_id == None  # noqa: E711 - SQLAlchemy renders IS NULL
        ).order_by(ActivationKey.id.desc()).first()
        if master:
            return master
        # Legacy data: device bound via activation_keys.device_id.
        return db.query(ActivationKey).filter(
            ActivationKey.device_id == device_id,
            ActivationKey.membership_type == membership_type,
            ActivationKey.status != KeyStatus.DISABLED,
            ActivationKey.master_key_id == None  # noqa: E711 - SQLAlchemy renders IS NULL
        ).order_by(ActivationKey.id.desc()).first()

    @staticmethod
    def create(
        db: Session,
        count: int = 1,
        membership_type: KeyMembershipType = KeyMembershipType.PRO,
        duration_days: int = 30,
        quota: int = 500,
        max_devices: int = 2,
        remark: Optional[str] = None
    ) -> List[ActivationKey]:
        """Create ``count`` UNUSED activation keys in one commit.

        ``duration_days`` is stored only for AUTO keys and ``quota`` (as both
        ``quota`` and ``quota_contribution``) only for PRO keys; the other is 0.

        Raises:
            ValueError: if a unique key string could not be generated within
                5 attempts.
        """
        keys = []
        max_retries = 5
        for _ in range(count):
            for _attempt in range(max_retries):
                key_str = generate_key()
                if not db.query(ActivationKey).filter(ActivationKey.key == key_str).first():
                    break
            else:
                # Every attempt collided with an existing key.
                raise ValueError("无法生成唯一激活码,请重试")
            is_pro = membership_type == KeyMembershipType.PRO
            db_key = ActivationKey(
                key=key_str,
                status=KeyStatus.UNUSED,
                membership_type=membership_type,
                duration_days=duration_days if membership_type == KeyMembershipType.AUTO else 0,
                quota_contribution=quota if is_pro else 0,
                quota=quota if is_pro else 0,
                max_devices=max_devices,
                remark=remark
            )
            db.add(db_key)
            keys.append(db_key)
        db.commit()
        for k in keys:
            db.refresh(k)
        return keys

    @staticmethod
    def get_by_key(db: Session, key: str) -> Optional[ActivationKey]:
        """Return the key whose code equals ``key``, or None."""
        return db.query(ActivationKey).filter(ActivationKey.key == key).first()

    @staticmethod
    def get_by_id(db: Session, key_id: int) -> Optional[ActivationKey]:
        """Return the key with primary key ``key_id``, or None."""
        return db.query(ActivationKey).filter(ActivationKey.id == key_id).first()

    @staticmethod
    def get_all(db: Session, skip: int = 0, limit: int = 100) -> List[ActivationKey]:
        """List keys newest-first (by id) with offset/limit paging."""
        return db.query(ActivationKey).order_by(ActivationKey.id.desc()).offset(skip).limit(limit).all()

    @staticmethod
    def activate(db: Session, key: ActivationKey, device_id: Optional[str] = None) -> Tuple[bool, str, Optional[ActivationKey]]:
        """Activate ``key`` for ``device_id`` (no account is allocated here).

        Outcomes:
          * merged child key: refused unless the device is already bound to the
            master, in which case the master key is returned instead;
          * DISABLED / EXPIRED key: refused;
          * ACTIVE key: expiry re-checked, device binding refreshed;
          * UNUSED key: merged into this device's master key of the same type
            (quota or duration added) when ``auto_merge_enabled`` is on and such
            a master exists, otherwise it becomes a new master key.

        The device-limit check runs before any field is modified, so a refused
        activation leaves nothing pending in the session.

        Returns:
            ``(success, message, key_to_use)``; ``key_to_use`` is None on failure.
        """
        now = datetime.now()

        # Merged child keys cannot be used directly.
        if key.master_key_id:
            master_key = KeyService.get_by_id(db, key.master_key_id)
            if not master_key:
                return False, "该密钥已合并,但主密钥不存在,请联系管理员", None
            # Redirect to the master only if this device is already bound to it.
            if device_id and KeyService._is_device_bound(db, master_key.id, device_id):
                ok, bind_msg = KeyService._ensure_device_bound(db, master_key, device_id)
                if not ok:
                    return False, bind_msg, None
                master_key.device_id = device_id
                master_key.last_active_at = now
                db.commit()
                return True, "该密钥已合并,已为你切换到主密钥", master_key
            return False, "该密钥已合并,请使用主密钥", None

        if key.status == KeyStatus.DISABLED:
            return False, "该密钥已被禁用", None
        if key.status == KeyStatus.EXPIRED:
            return False, "该密钥已过期", None

        # Already active: re-check expiry and refresh the device binding.
        if key.status == KeyStatus.ACTIVE:
            if key.is_expired:
                key.status = KeyStatus.EXPIRED
                db.commit()
                return False, "该密钥已过期", None
            if device_id:
                ok, bind_msg = KeyService._ensure_device_bound(db, key, device_id)
                if not ok:
                    return False, bind_msg, None
                key.device_id = device_id
            key.last_active_at = now
            db.commit()
            return True, "密钥已激活", key

        # First use: look for a master key of the same type on this device.
        master_key = None
        auto_merge_enabled = KeyService._get_bool_setting(db, "auto_merge_enabled", default=True)
        if auto_merge_enabled and device_id:
            master_key = KeyService._find_master_for_device(db, key.membership_type, device_id)

        if master_key:
            # master_key is only found when device_id is non-empty. Bind first so
            # a device-limit failure happens before anything is mutated.
            ok, bind_msg = KeyService._ensure_device_bound(db, master_key, device_id)
            if not ok:
                return False, bind_msg, None

            key.master_key_id = master_key.id
            key.merged_at = now
            key.device_id = device_id
            # Marked ACTIVE as "consumed"; master_key_id keeps it from direct use.
            key.status = KeyStatus.ACTIVE

            # Stack resources onto the master: quota for PRO, duration otherwise.
            if key.membership_type == KeyMembershipType.PRO:
                master_key.quota += key.quota_contribution
            else:
                base = master_key.expire_at if master_key.expire_at and master_key.expire_at > now else now
                master_key.expire_at = base + timedelta(days=key.duration_days)
            # Merging can bring an expired (or never-used) master back to life.
            if master_key.status in (KeyStatus.EXPIRED, KeyStatus.UNUSED):
                master_key.status = KeyStatus.ACTIVE
            master_key.merged_count += 1
            master_key.device_id = device_id
            master_key.last_active_at = now
            db.commit()
            return True, f"密钥已合并,{'积分' if key.membership_type == KeyMembershipType.PRO else '时长'}已叠加", master_key

        # No master found: this key becomes a master key.
        if device_id:
            ok, bind_msg = KeyService._ensure_device_bound(db, key, device_id)
            if not ok:
                return False, bind_msg, None
        key.status = KeyStatus.ACTIVE
        key.device_id = device_id
        key.first_activated_at = now
        key.last_active_at = now
        # Initial expiry is set only for AUTO keys with a positive duration.
        if key.membership_type == KeyMembershipType.AUTO and key.duration_days > 0:
            key.expire_at = now + timedelta(days=key.duration_days)
        db.commit()
        return True, "激活成功", key

    @staticmethod
    def enable_seamless(db: Session, key: ActivationKey, device_id: str) -> Tuple[bool, str, Optional[CursorAccount]]:
        """Enable seamless switching for ``key`` and make sure it holds an account.

        If seamless is already on and the current account is still IN_USE, that
        account is returned. Otherwise a new account is picked, locked and
        recorded on the key.

        Returns:
            ``(success, message, account)``; ``account`` is None on failure.
        """
        if key.status != KeyStatus.ACTIVE:
            return False, "密钥未激活", None
        if key.is_expired:
            key.status = KeyStatus.EXPIRED
            db.commit()
            return False, "密钥已过期", None

        if key.seamless_enabled and key.current_account_id:
            account = AccountService.get_by_id(db, key.current_account_id)
            if account and account.status == AccountStatus.IN_USE:
                return True, "无感换号已启用", account

        account = AccountService.get_available_for_key(db, key)
        if not account:
            return False, "无可用账号", None
        AccountService.lock_account(db, account, key.id)

        key.seamless_enabled = True
        key.current_account_id = account.id
        key.device_id = device_id
        key.last_active_at = datetime.now()
        db.commit()
        return True, "无感换号已启用", account

    @staticmethod
    def disable_seamless(db: Session, key: ActivationKey) -> Tuple[bool, str]:
        """Turn off seamless switching, releasing the key's current account if any."""
        if key.current_account_id:
            account = AccountService.get_by_id(db, key.current_account_id)
            if account:
                AccountService.release_account(db, account)
        key.seamless_enabled = False
        key.current_account_id = None
        db.commit()
        return True, "无感换号已禁用"

    @staticmethod
    def switch_account(db: Session, key: ActivationKey) -> Tuple[bool, str, Optional[CursorAccount]]:
        """Swap ``key``'s current account for a fresh one and log the switch.

        PRO keys pay ``pro_quota_per_switch`` quota per switch. AUTO keys must
        wait ``auto_switch_interval`` minutes between switches (0 disables the
        cooldown). A new account is found before the old one is released, so a
        failed lookup leaves the key on its old account.

        Returns:
            ``(success, message, new_account)``; ``new_account`` is None on failure.
        """
        if key.status != KeyStatus.ACTIVE:
            return False, "密钥未激活", None
        if not key.seamless_enabled:
            return False, "未启用无感换号", None
        if key.is_expired:
            key.status = KeyStatus.EXPIRED
            db.commit()
            return False, "密钥已过期", None

        quota_cost = 0
        if key.membership_type == KeyMembershipType.PRO:
            quota_cost = GlobalSettingsService.get_int(db, "pro_quota_per_switch")
            if key.quota_remaining < quota_cost:
                return False, f"积分不足,需要{quota_cost},剩余{key.quota_remaining}", None

        if key.membership_type == KeyMembershipType.AUTO:
            interval_minutes = GlobalSettingsService.get_int(db, "auto_switch_interval") or 0
            if interval_minutes > 0 and key.last_switch_at:
                elapsed = (datetime.now() - key.last_switch_at).total_seconds()
                remaining_seconds = interval_minutes * 60 - elapsed
                if remaining_seconds > 0:
                    remaining_minutes = int(remaining_seconds / 60) + 1
                    return False, f"换号冷却中,请{remaining_minutes}分钟后再试", None

        old_account = None
        if key.current_account_id:
            old_account = AccountService.get_by_id(db, key.current_account_id)

        # The current account is IN_USE, so the AVAILABLE-only lookup cannot return it.
        new_account = AccountService.get_available_for_key(db, key)
        if not new_account:
            return False, "无可用账号", None

        # Only release the old account once a replacement is secured.
        if old_account:
            AccountService.release_account(db, old_account)
        AccountService.lock_account(db, new_account, key.id)

        key.current_account_id = new_account.id
        if key.membership_type == KeyMembershipType.PRO:
            key.quota_used += quota_cost
        key.switch_count += 1
        key.last_switch_at = datetime.now()
        db.commit()

        LogService.log(
            db, key.id, "switch",
            account_id=new_account.id,
            success=True,
            message=f"{old_account.email if old_account else 'N/A'} 切换到 {new_account.email}",
            usage_snapshot={
                "old_account": old_account.to_dict() if old_account else None,
                "new_account": new_account.to_dict()
            }
        )
        return True, "换号成功", new_account

    @staticmethod
    def get_status(db: Session, key: ActivationKey) -> Dict[str, Any]:
        """Return ``key.to_dict(include_account=True)``, adding ``account_info`` when seamless is on and the account exists."""
        data = key.to_dict(include_account=True)
        if key.seamless_enabled and key.current_account_id:
            account = AccountService.get_by_id(db, key.current_account_id)
            if account:
                data["account_info"] = account.to_dict()
        return data

    @staticmethod
    def update(db: Session, key_id: int, **kwargs) -> Optional[ActivationKey]:
        """Set each existing attribute named in ``kwargs`` whose value is not None; commit.

        None values are skipped. Returns the key, or None if it does not exist.
        """
        key = KeyService.get_by_id(db, key_id)
        if key:
            for k, v in kwargs.items():
                if hasattr(key, k) and v is not None:
                    setattr(key, k, v)
            db.commit()
            db.refresh(key)
        return key

    @staticmethod
    def delete(db: Session, key_id: int) -> bool:
        """Delete a key after releasing its account and removing its device bindings.

        Returns False if the key does not exist.
        """
        key = KeyService.get_by_id(db, key_id)
        if not key:
            return False
        if key.current_account_id:
            account = AccountService.get_by_id(db, key.current_account_id)
            if account:
                AccountService.release_account(db, account)
        db.query(KeyDevice).filter(KeyDevice.key_id == key_id).delete()
        db.delete(key)
        db.commit()
        return True

    @staticmethod
    def count(db: Session) -> dict:
        """Return the total key count plus counts per status and per membership type."""
        status_rows = db.query(
            ActivationKey.status, func.count(ActivationKey.id)
        ).group_by(ActivationKey.status).all()
        sc = {s.value: c for s, c in status_rows}
        type_rows = db.query(
            ActivationKey.membership_type, func.count(ActivationKey.id)
        ).group_by(ActivationKey.membership_type).all()
        tc = {t.value: c for t, c in type_rows}
        return {
            "total": sum(sc.values()),
            "unused": sc.get(KeyStatus.UNUSED.value, 0),
            "active": sc.get(KeyStatus.ACTIVE.value, 0),
            "expired": sc.get(KeyStatus.EXPIRED.value, 0),
            "auto": tc.get(KeyMembershipType.AUTO.value, 0),
            "pro": tc.get(KeyMembershipType.PRO.value, 0),
        }
# ==================== 日志服务 ====================
class LogService:
    """Write and read ``UsageLog`` audit rows."""

    @staticmethod
    def log(
        db: Session,
        key_id: int,
        action: str,
        account_id: int = None,
        ip_address: str = None,
        user_agent: str = None,
        device_id: str = None,
        success: bool = True,
        message: str = None,
        usage_snapshot: Dict = None
    ):
        """Insert a single UsageLog row built from the arguments and commit immediately."""
        db.add(UsageLog(
            key_id=key_id,
            action=action,
            account_id=account_id,
            device_id=device_id,
            ip_address=ip_address,
            user_agent=user_agent,
            success=success,
            message=message,
            usage_snapshot=usage_snapshot,
        ))
        db.commit()

    @staticmethod
    def get_by_key(db: Session, key_id: int, limit: int = 50) -> List[UsageLog]:
        """Return up to ``limit`` log rows for ``key_id``, most recent first."""
        newest_first = UsageLog.created_at.desc()
        for_key = db.query(UsageLog).filter(UsageLog.key_id == key_id)
        return for_key.order_by(newest_first).limit(limit).all()

    @staticmethod
    def get_today_count(db: Session) -> int:
        """Count log rows whose ``created_at`` date equals today's local date."""
        todays = db.query(UsageLog).filter(
            func.date(UsageLog.created_at) == datetime.now().date()
        )
        return todays.count()
# ==================== 全局设置服务 ====================
class GlobalSettingsService:
    """Key/value access to ``GlobalSettings`` rows, with model-provided defaults."""

    @staticmethod
    def init_settings(db: Session):
        """Insert each default setting whose key has no row yet, then commit once."""
        for default_row in GlobalSettings.get_default_settings():
            found = db.query(GlobalSettings).filter(GlobalSettings.key == default_row["key"]).first()
            if found:
                continue
            db.add(GlobalSettings(**default_row))
        db.commit()

    @staticmethod
    def get(db: Session, key: str) -> Optional[str]:
        """Return the stored value for ``key``, else its built-in default, else None."""
        row = db.query(GlobalSettings).filter(GlobalSettings.key == key).first()
        if not row:
            fallback = {item["key"]: item["value"] for item in GlobalSettings.get_default_settings()}
            return fallback.get(key)
        return row.value

    @staticmethod
    def get_int(db: Session, key: str) -> int:
        """Return the setting parsed with ``int()``; 0 when it is missing or empty."""
        raw = GlobalSettingsService.get(db, key)
        if not raw:
            return 0
        return int(raw)

    @staticmethod
    def set(db: Session, key: str, value: str, description: str = None):
        """Create or overwrite the setting ``key``; a falsy description leaves an existing one intact."""
        row = db.query(GlobalSettings).filter(GlobalSettings.key == key).first()
        if not row:
            db.add(GlobalSettings(key=key, value=value, description=description))
        else:
            row.value = value
            if description:
                row.description = description
        db.commit()

    @staticmethod
    def get_all(db: Session) -> dict:
        """Return every stored setting, converted according to its ``value_type``.

        "int" and "float" use the builtin constructors; "bool" is True for
        "true", "1" or "yes" (case-insensitive); any other type stays a string.
        """
        converters = {
            "int": int,
            "float": float,
            "bool": lambda raw: raw.lower() in ("true", "1", "yes"),
        }
        parsed = {}
        for row in db.query(GlobalSettings).all():
            convert = converters.get(row.value_type)
            parsed[row.key] = convert(row.value) if convert else row.value
        return parsed

    @staticmethod
    def update_all(db: Session, **kwargs):
        """Store every non-None keyword argument as a setting (value passed through ``str()``)."""
        provided = ((name, val) for name, val in kwargs.items() if val is not None)
        for name, val in provided:
            GlobalSettingsService.set(db, name, str(val))
# ==================== 批量操作服务 ====================
class BatchService:
    """Bulk operations on activation keys for the admin backend."""

    @staticmethod
    def _extend_expiry(key: ActivationKey, extend_days: int, now: datetime) -> None:
        """Push ``key.expire_at`` forward by ``extend_days`` days and revive EXPIRED keys.

        A key with no expiry, or one already in the past, restarts from ``now``;
        a still-valid key is extended from its current expiry. Because the new
        expiry is in the future, an EXPIRED key is set back to ACTIVE —
        otherwise ``KeyService.activate`` would keep rejecting it.
        """
        if key.expire_at and key.expire_at >= now:
            key.expire_at = key.expire_at + timedelta(days=extend_days)
        else:
            key.expire_at = now + timedelta(days=extend_days)
        if key.status == KeyStatus.EXPIRED:
            key.status = KeyStatus.ACTIVE

    @staticmethod
    def extend_keys(
        db: Session,
        key_ids: List[int],
        extend_days: int = 0,
        add_quota: int = 0
    ) -> dict:
        """Extend the keys listed in ``key_ids`` and commit once.

        Args:
            extend_days: days to add to the expiry (any key type; ignored if <= 0).
            add_quota: quota to add (PRO keys only; ignored if <= 0).

        Returns:
            ``{"success", "failed", "errors"}``; ``errors`` holds at most the
            first 20 messages. A per-key exception is recorded, not raised.
        """
        success = 0
        failed = 0
        errors = []
        now = datetime.now()
        for key_id in key_ids:
            try:
                key = db.query(ActivationKey).filter(ActivationKey.id == key_id).first()
                if not key:
                    failed += 1
                    errors.append(f"ID {key_id}: 密钥不存在")
                    continue
                if extend_days > 0:
                    BatchService._extend_expiry(key, extend_days, now)
                if add_quota > 0 and key.membership_type == KeyMembershipType.PRO:
                    key.quota = (key.quota or 0) + add_quota
                success += 1
            except Exception as e:
                failed += 1
                errors.append(f"ID {key_id}: {str(e)}")
        db.commit()
        return {
            "success": success,
            "failed": failed,
            "errors": errors[:20]
        }

    @staticmethod
    def get_keys_for_compensation(
        db: Session,
        membership_type: Optional[KeyMembershipType] = None,
        activated_before: Optional[datetime] = None,
        not_expired_on: Optional[datetime] = None
    ) -> List[ActivationKey]:
        """Return activated keys (``first_activated_at`` set) matching every given filter.

        Args:
            membership_type: only keys of this type (AUTO/PRO).
            activated_before: only keys first activated strictly before this time.
            not_expired_on: only keys still valid at this time
                (``expire_at`` after it, or no ``expire_at`` at all).
        """
        query = db.query(ActivationKey).filter(
            ActivationKey.first_activated_at != None  # noqa: E711 - SQLAlchemy renders IS NOT NULL
        )
        if membership_type:
            query = query.filter(ActivationKey.membership_type == membership_type)
        if activated_before:
            query = query.filter(ActivationKey.first_activated_at < activated_before)
        if not_expired_on:
            query = query.filter(
                or_(
                    ActivationKey.expire_at > not_expired_on,
                    ActivationKey.expire_at == None  # noqa: E711 - SQLAlchemy renders IS NULL
                )
            )
        return query.all()

    @staticmethod
    def batch_compensate(
        db: Session,
        membership_type: Optional[KeyMembershipType] = None,
        activated_before: Optional[datetime] = None,
        not_expired_on: Optional[datetime] = None,
        extend_days: int = 0,
        add_quota: int = 0
    ) -> dict:
        """Compensate every key selected by ``get_keys_for_compensation`` and commit once.

        A still-valid key gets ``expire_at += extend_days``; a key already past
        its expiry restarts from now. EXPIRED keys are revived to ACTIVE.
        ``add_quota`` applies to PRO keys only.

        Returns:
            A summary dict with success/failure counts, ``total_matched``, the
            applied amounts, up to 20 error messages and a human-readable message.
        """
        keys = BatchService.get_keys_for_compensation(
            db, membership_type, activated_before, not_expired_on
        )
        if not keys:
            return {
                "success": 0,
                "failed": 0,
                "total_matched": 0,
                "message": "没有找到符合条件的密钥"
            }

        success = 0
        failed = 0
        errors = []
        now = datetime.now()
        for key in keys:
            try:
                if extend_days > 0:
                    BatchService._extend_expiry(key, extend_days, now)
                if add_quota > 0 and key.membership_type == KeyMembershipType.PRO:
                    key.quota = (key.quota or 0) + add_quota
                success += 1
            except Exception as e:
                failed += 1
                errors.append(f"密钥 {key.key[:8]}...: {str(e)}")
        db.commit()

        message = f"成功补偿 {success} 个密钥"
        if failed > 0:
            # Separator was missing, producing e.g. "成功补偿 5 个密钥3 个失败".
            message += f",{failed} 个失败"
        return {
            "success": success,
            "failed": failed,
            "total_matched": len(keys),
            "extend_days": extend_days,
            "add_quota": add_quota,
            "errors": errors[:20],
            "message": message
        }