* fix(ops): 修复运维监控系统的关键安全和稳定性问题
## 修复内容
### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
- 实现IP钉住机制防止验证后的DNS rebinding攻击
- 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
- 扩展IP黑名单,包括云metadata地址(169.254.169.254)
- 添加完整的单元测试覆盖
2. **OpsAlertService生命周期管理** (wire.go)
- 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
- 确保stopCtx正确初始化,避免nil指针问题
- 实现防御式启动,保证服务启动顺序
3. **数据库查询排序** (ops_repo.go)
- 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
- 在GetLatestSystemMetric中添加排序保证
- 避免数据库返回顺序不确定导致告警误判
### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
- 为lastGCPauseTotal字段添加sync.Mutex保护
- 防止数据竞争
5. **Goroutine泄漏** (ops_error_logger.go)
- 实现worker pool模式限制并发goroutine数量
- 使用256容量缓冲队列和10个固定worker
- 非阻塞投递,队列满时丢弃任务
6. **生命周期控制** (ops_alert_service.go)
- 添加Start/Stop方法实现优雅关闭
- 使用context控制goroutine生命周期
- 实现WaitGroup等待后台任务完成
7. **Webhook URL验证** (ops_alert_service.go)
- 防止SSRF攻击:验证scheme、禁止内网IP
- DNS解析验证,拒绝解析到私有IP的域名
- 添加8个单元测试覆盖各种攻击场景
8. **资源泄漏** (ops_repo.go)
- 修复多处defer rows.Close()问题
- 简化冗余的defer func()包装
9. **HTTP超时控制** (ops_alert_service.go)
- 创建带10秒超时的http.Client
- 添加buildWebhookHTTPClient辅助函数
- 防止HTTP请求无限期挂起
10. **数据库查询优化** (ops_repo.go)
- 将GetWindowStats的4次独立查询合并为1次CTE查询
- 减少网络往返和表扫描次数
- 显著提升性能
11. **重试机制** (ops_alert_service.go)
- 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
- 添加webhook备用通道
- 实现完整的错误处理和日志记录
12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
- 提取硬编码数字为有意义的常量
- 提高代码可读性和可维护性
## 测试验证
- ✅ go test ./internal/service -tags opsalert_unit 通过
- ✅ 所有webhook验证测试通过
- ✅ 重试机制测试通过
## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容
* feat(ops): 运维监控系统V2 - 完整实现
## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)
## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)
### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置
### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数
## 一致性修复(98/100分)
### P0级别(阻塞Migration)
- ✅ 修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
- ✅ 修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
- ✅ 统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
- ✅ 统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)
### P1级别(功能完整性)
- ✅ 修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
- ✅ 修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
- ✅ 添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)
### P2级别(优化)
- ✅ 前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
- ✅ 后端WebSocket心跳检测(30s ping,60s pong超时)
## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)
### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)
## 测试验证
- ✅ 所有Go测试通过
- ✅ Migration可正常执行
- ✅ WebSocket连接稳定
- ✅ 前后端数据结构对齐
* refactor: 代码清理和测试优化
## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式
## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理
## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式
变更统计: 27个文件,292行新增,322行删除(净减少30行)
* fix(ops): 运维监控系统安全加固和功能优化
## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token
## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
- OPS_WS_TRUST_PROXY
- OPS_WS_TRUSTED_PROXIES
- OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用
## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离
## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes
## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go
Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts
Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)
Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)
* fix(migrations): 修复calculate_health_score函数类型匹配问题
在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配
* fix(lint): 修复golangci-lint检查发现的所有问题
- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段
修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)
代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
479 lines
14 KiB
Go
479 lines
14 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||
)
|
||
|
||
var (
|
||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||
)
|
||
|
||
const (
|
||
apiKeyMaxErrorsPerHour = 20
|
||
)
|
||
|
||
type APIKeyRepository interface {
|
||
Create(ctx context.Context, key *APIKey) error
|
||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||
Update(ctx context.Context, key *APIKey) error
|
||
Delete(ctx context.Context, id int64) error
|
||
|
||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||
}
|
||
|
||
// APIKeyCache defines cache operations for API key service
|
||
type APIKeyCache interface {
|
||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||
|
||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||
}
|
||
|
||
// CreateAPIKeyRequest 创建API Key请求
|
||
type CreateAPIKeyRequest struct {
|
||
Name string `json:"name"`
|
||
GroupID *int64 `json:"group_id"`
|
||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||
}
|
||
|
||
// UpdateAPIKeyRequest 更新API Key请求
|
||
type UpdateAPIKeyRequest struct {
|
||
Name *string `json:"name"`
|
||
GroupID *int64 `json:"group_id"`
|
||
Status *string `json:"status"`
|
||
}
|
||
|
||
// APIKeyService API Key服务
|
||
type APIKeyService struct {
|
||
apiKeyRepo APIKeyRepository
|
||
userRepo UserRepository
|
||
groupRepo GroupRepository
|
||
userSubRepo UserSubscriptionRepository
|
||
cache APIKeyCache
|
||
cfg *config.Config
|
||
}
|
||
|
||
// NewAPIKeyService 创建API Key服务实例
|
||
func NewAPIKeyService(
|
||
apiKeyRepo APIKeyRepository,
|
||
userRepo UserRepository,
|
||
groupRepo GroupRepository,
|
||
userSubRepo UserSubscriptionRepository,
|
||
cache APIKeyCache,
|
||
cfg *config.Config,
|
||
) *APIKeyService {
|
||
return &APIKeyService{
|
||
apiKeyRepo: apiKeyRepo,
|
||
userRepo: userRepo,
|
||
groupRepo: groupRepo,
|
||
userSubRepo: userSubRepo,
|
||
cache: cache,
|
||
cfg: cfg,
|
||
}
|
||
}
|
||
|
||
// GenerateKey 生成随机API Key
|
||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||
// 生成32字节随机数据
|
||
bytes := make([]byte, 32)
|
||
if _, err := rand.Read(bytes); err != nil {
|
||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||
}
|
||
|
||
// 转换为十六进制字符串并添加前缀
|
||
prefix := s.cfg.Default.APIKeyPrefix
|
||
if prefix == "" {
|
||
prefix = "sk-"
|
||
}
|
||
|
||
key := prefix + hex.EncodeToString(bytes)
|
||
return key, nil
|
||
}
|
||
|
||
// ValidateCustomKey 验证自定义API Key格式
|
||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||
// 检查长度
|
||
if len(key) < 16 {
|
||
return ErrAPIKeyTooShort
|
||
}
|
||
|
||
// 检查字符:只允许字母、数字、下划线、连字符
|
||
for _, c := range key {
|
||
if (c >= 'a' && c <= 'z') ||
|
||
(c >= 'A' && c <= 'Z') ||
|
||
(c >= '0' && c <= '9') ||
|
||
c == '_' || c == '-' {
|
||
continue
|
||
}
|
||
return ErrAPIKeyInvalidChars
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
|
||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||
if err != nil {
|
||
// Redis 出错时不阻止用户操作
|
||
return nil
|
||
}
|
||
|
||
if count >= apiKeyMaxErrorsPerHour {
|
||
return ErrAPIKeyRateLimited
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||
}
|
||
|
||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||
// 订阅类型分组:需要有效订阅
|
||
if group.IsSubscriptionType() {
|
||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||
return err == nil // 有有效订阅则允许
|
||
}
|
||
// 标准类型分组:使用原有逻辑
|
||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||
}
|
||
|
||
// Create 创建API Key
|
||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||
// 验证用户存在
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
// 验证分组权限(如果指定了分组)
|
||
if req.GroupID != nil {
|
||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get group: %w", err)
|
||
}
|
||
|
||
// 检查用户是否可以绑定该分组
|
||
if !s.canUserBindGroup(ctx, user, group) {
|
||
return nil, ErrGroupNotAllowed
|
||
}
|
||
}
|
||
|
||
var key string
|
||
|
||
// 判断是否使用自定义Key
|
||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||
// 检查限流(仅对自定义key进行限流)
|
||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证自定义Key格式
|
||
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 检查Key是否已存在
|
||
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("check key exists: %w", err)
|
||
}
|
||
if exists {
|
||
// Key已存在,增加错误计数
|
||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||
return nil, ErrAPIKeyExists
|
||
}
|
||
|
||
key = *req.CustomKey
|
||
} else {
|
||
// 生成随机API Key
|
||
var err error
|
||
key, err = s.GenerateKey()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate key: %w", err)
|
||
}
|
||
}
|
||
|
||
// 创建API Key记录
|
||
apiKey := &APIKey{
|
||
UserID: userID,
|
||
Key: key,
|
||
Name: req.Name,
|
||
GroupID: req.GroupID,
|
||
Status: StatusActive,
|
||
}
|
||
|
||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||
return nil, fmt.Errorf("create api key: %w", err)
|
||
}
|
||
|
||
return apiKey, nil
|
||
}
|
||
|
||
// List 获取用户的API Key列表
|
||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||
}
|
||
return keys, pagination, nil
|
||
}
|
||
|
||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||
if len(apiKeyIDs) == 0 {
|
||
return []int64{}, nil
|
||
}
|
||
|
||
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||
}
|
||
return validIDs, nil
|
||
}
|
||
|
||
// GetByID 根据ID获取API Key
|
||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get api key: %w", err)
|
||
}
|
||
return apiKey, nil
|
||
}
|
||
|
||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||
// 尝试从Redis缓存获取
|
||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||
|
||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get api key: %w", err)
|
||
}
|
||
|
||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||
if s.cache != nil {
|
||
// 这里可以序列化并缓存API Key
|
||
_ = cacheKey // 使用变量避免未使用错误
|
||
}
|
||
|
||
return apiKey, nil
|
||
}
|
||
|
||
// Update 更新API Key
|
||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get api key: %w", err)
|
||
}
|
||
|
||
// 验证所有权
|
||
if apiKey.UserID != userID {
|
||
return nil, ErrInsufficientPerms
|
||
}
|
||
|
||
// 更新字段
|
||
if req.Name != nil {
|
||
apiKey.Name = *req.Name
|
||
}
|
||
|
||
if req.GroupID != nil {
|
||
// 验证分组权限
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get group: %w", err)
|
||
}
|
||
|
||
if !s.canUserBindGroup(ctx, user, group) {
|
||
return nil, ErrGroupNotAllowed
|
||
}
|
||
|
||
apiKey.GroupID = req.GroupID
|
||
}
|
||
|
||
if req.Status != nil {
|
||
apiKey.Status = *req.Status
|
||
// 如果状态改变,清除Redis缓存
|
||
if s.cache != nil {
|
||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||
}
|
||
}
|
||
|
||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||
return nil, fmt.Errorf("update api key: %w", err)
|
||
}
|
||
|
||
return apiKey, nil
|
||
}
|
||
|
||
// Delete 删除API Key
|
||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||
if err != nil {
|
||
return fmt.Errorf("get api key: %w", err)
|
||
}
|
||
|
||
// 验证当前用户是否为该 API Key 的所有者
|
||
if ownerID != userID {
|
||
return ErrInsufficientPerms
|
||
}
|
||
|
||
// 清除Redis缓存(使用 ownerID 而非 apiKey.UserID)
|
||
if s.cache != nil {
|
||
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
|
||
}
|
||
|
||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||
return fmt.Errorf("delete api key: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||
// 获取API Key
|
||
apiKey, err := s.GetByKey(ctx, key)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
// 检查API Key状态
|
||
if !apiKey.IsActive() {
|
||
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
|
||
}
|
||
|
||
// 获取用户信息
|
||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
// 检查用户状态
|
||
if !user.IsActive() {
|
||
return nil, nil, ErrUserNotActive
|
||
}
|
||
|
||
return apiKey, user, nil
|
||
}
|
||
|
||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||
// 使用Redis计数器
|
||
if s.cache != nil {
|
||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||
return fmt.Errorf("increment usage: %w", err)
|
||
}
|
||
// 设置24小时过期
|
||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetAvailableGroups 获取用户有权限绑定的分组列表
|
||
// 返回用户可以选择的分组:
|
||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||
// - 订阅类型分组:用户有有效订阅的
|
||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||
// 获取用户信息
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get user: %w", err)
|
||
}
|
||
|
||
// 获取所有活跃分组
|
||
allGroups, err := s.groupRepo.ListActive(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("list active groups: %w", err)
|
||
}
|
||
|
||
// 获取用户的所有有效订阅
|
||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||
}
|
||
|
||
// 构建订阅分组 ID 集合
|
||
subscribedGroupIDs := make(map[int64]bool)
|
||
for _, sub := range activeSubscriptions {
|
||
subscribedGroupIDs[sub.GroupID] = true
|
||
}
|
||
|
||
// 过滤出用户有权限的分组
|
||
availableGroups := make([]Group, 0)
|
||
for _, group := range allGroups {
|
||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||
availableGroups = append(availableGroups, group)
|
||
}
|
||
}
|
||
|
||
return availableGroups, nil
|
||
}
|
||
|
||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||
// 订阅类型分组:需要有效订阅
|
||
if group.IsSubscriptionType() {
|
||
return subscribedGroupIDs[group.ID]
|
||
}
|
||
// 标准类型分组:使用原有逻辑
|
||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||
}
|
||
|
||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("search api keys: %w", err)
|
||
}
|
||
return keys, nil
|
||
}
|