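Security hardening and feature improvements for the ops monitoring system (#21)
* fix(ops): fix critical security and stability issues in the ops monitoring system
## Fixes
### P0: critical issues
1. **DNS rebinding protection** (ops_alert_service.go)
- Pin the validated IP to prevent DNS rebinding attacks after validation
- Use a custom Transport.DialContext that only dials the validated public IP
- Extend the IP blocklist to include the cloud metadata address (169.254.169.254)
- Add full unit test coverage
2. **OpsAlertService lifecycle management** (wire.go)
- Call opsAlertService.Start() in ProvideOpsMetricsCollector
- Make sure stopCtx is initialized correctly to avoid nil pointer dereferences
- Start defensively to guarantee the service startup order
3. **Database query ordering** (ops_repo.go)
- Add an explicit ORDER BY updated_at DESC, id DESC to ListRecentSystemMetrics
- Add an ordering guarantee to GetLatestSystemMetric
- Avoid false alerts caused by nondeterministic row order from the database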
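### P1: important issues
4. **Concurrency safety** (ops_metrics_collector.go)
- Protect the lastGCPauseTotal field with a sync.Mutex
- Prevent data races
5. **Goroutine leak** (ops_error_logger.go)
- Use a worker pool to cap the number of concurrent goroutines
- Use a buffered queue with capacity 256 and 10 fixed workers
- Submit without blocking; drop the task when the queue is full
6. **Lifecycle control** (ops_alert_service.go)
- Add Start/Stop methods for graceful shutdown
- Control goroutine lifetime with a context
- Use a WaitGroup to wait for background tasks to finish
7. **Webhook URL validation** (ops_alert_service.go)
- Prevent SSRF: validate the scheme and reject internal IPs
- Validate DNS resolution and reject domains that resolve to private IPs
- Add 8 unit tests covering various attack scenarios
8. **Resource leaks** (ops_repo.go)
- Fix several defer rows.Close() problems
- Simplify redundant defer func() wrappers
9. **HTTP timeout control** (ops_alert_service.go)
- Create an http.Client with a 10-second timeout
- Add the buildWebhookHTTPClient helper
- Prevent HTTP requests from hanging indefinitely
10. **Database query optimization** (ops_repo.go)
- Merge the 4 separate queries in GetWindowStats into a single CTE query
- Reduce network round trips and table scans
- Significantly improve performance
11. **Retry mechanism** (ops_alert_service.go)
- Retry email delivery: up to 3 attempts with exponential backoff (1s/2s/4s)
- Add a webhook fallback channel
- Implement full error handling and logging
12. **Magic numbers** (ops_repo.go, ops_metrics_collector.go)
- Extract hard-coded numbers into meaningful constants
- Improve readability and maintainability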
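## Test verification
- ✅ go test ./internal/service -tags opsalert_unit passes
- ✅ All webhook validation tests pass
- ✅ Retry mechanism tests pass
## Impact
- Security of the ops monitoring system is significantly improved
- System stability and performance are improved
- No breaking changes; backward compatible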
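* feat(ops): ops monitoring system V2 - full implementation
## Core features
- Ops monitoring dashboard V2 (real-time monitoring, historical trends, alert management)
- Real-time QPS/TPS monitoring over WebSocket (30s heartbeat, automatic reconnect)
- System metrics collection (CPU, memory, latency, error rate, and so on)
- Multi-dimensional statistics (by provider, model, user, and other dimensions)
- Alert rule management (threshold configuration, notification channels)
- Error log tracing (detailed error information, stack traces)
## Database schema (Migration 025)
### Extended existing tables
- ops_system_metrics: adds RED metrics, error classification, latency metrics, resource metrics, and business metrics
- ops_alert_rules: adds JSONB columns (dimension_filters, notify_channels, notify_config)
### New tables
- ops_dimension_stats: multi-dimensional statistics
- ops_data_retention_config: data retention policy configuration
### New views and functions
- ops_latest_metrics: metrics for the latest 1-minute window (column names and window filter fixed)
- ops_active_alerts: currently active alerts (column names and status values fixed)
- calculate_health_score: health score calculation function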
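## Consistency fixes (score 98/100)
### P0 (blocks the migration)
- ✅ Fix column names in the ops_latest_metrics view (latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
- ✅ Fix column names in the ops_active_alerts view (metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
- ✅ Unify the alert history table name (drop ops_alert_history, use ops_alert_events)
- ✅ Unify API parameter limits (limit for ListMetricsHistory and ListErrorLogs changed to 5000)
### P1 (functional completeness)
- ✅ Fix the ops_latest_metrics view not filtering on window_minutes (add WHERE m.window_minutes = 1)
- ✅ Fix the data backfill UPDATE logic (QPS is now computed as request_count/(window_minutes*60.0))
- ✅ Add backend support for the ops_alert_rules JSONB columns (Go structs and serialization)
### P2 (optimization)
- ✅ Frontend WebSocket automatic reconnect (exponential backoff 1s→2s→4s→8s→16s, at most 5 attempts)
- ✅ Backend WebSocket heartbeat detection (30s ping, 60s pong timeout)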
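## Technical implementation
### Backend (Go)
- Handler layer: ops_handler.go (REST API), ops_ws_handler.go (WebSocket)
- Service layer: ops_service.go (core logic), ops_cache.go (cache), ops_alerts.go (alerts)
- Repository layer: ops_repo.go (data access), ops.go (model definitions)
- Routing: admin.go (new ops routes)
- Dependency injection: wire_gen.go (generated)
### Frontend (Vue3 + TypeScript)
- Component: OpsDashboardV2.vue (main dashboard component)
- API: ops.ts (REST API and WebSocket wrapper)
- Routing: index.ts (new /admin/ops route)
- i18n: en.ts, zh.ts (English and Chinese)
## Test verification
- ✅ All Go tests pass
- ✅ The migration runs cleanly
- ✅ WebSocket connections are stable
- ✅ Frontend and backend data structures are aligned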
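* refactor: code cleanup and test improvements
## Test file improvements
- Simplify integration test fixtures and assertions
- Improve test helper functions
- Unify the test data format
## Code cleanup
- Remove unused code and comments
- Simplify the concurrency_cache implementation
- Improve middleware error handling
## Minor fixes
- Fix minor issues in gateway_handler and openai_gateway_handler
- Unify code style and formatting
Change statistics: 27 files, 292 additions, 322 deletions (net reduction of 30 lines)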
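* fix(ops): security hardening and feature improvements for the ops monitoring system
## Security enhancements
- feat(security): redact WebSocket logs to prevent token/api_key leaks
- feat(security): validate X-Forwarded-Host against an allowlist to prevent CSRF bypass
- feat(security): make the Origin policy configurable, with strict/permissive modes
- feat(auth): allow WebSocket authentication to pass the token as a query parameter
## Configuration improvements
- feat(config): support configuring proxy trust and the Origin policy through environment variables
- OPS_WS_TRUST_PROXY
- OPS_WS_TRUSTED_PROXIES
- OPS_WS_ORIGIN_POLICY
- fix(ops): lower the error log query limit from 5000 to 500 to reduce memory usage
## Architecture improvements
- refactor(ops): decouple the alert service so it runs its own evaluation timer
- refactor(ops): unify OpsDashboard into a single version and remove the separate V2
## Tests and documentation
- test(ops): add unit tests for WebSocket security validation (8 test cases)
- test(ops): add integration tests for the alert service
- docs(api): update the API documentation and note the limit change
- docs: add a CHANGELOG recording breaking changes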
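Because the token may now travel in the query string, request URLs have to be redacted before they are logged. The sketch below shows one way to do that with a regular expression. The function name `redact`, the parameter list (`token`, `api_key`) and the sample URL are assumptions for illustration, not the code in middleware/logger.go.

```go
package main

import (
	"fmt"
	"regexp"
)

// sensitiveParam matches token=... and api_key=... pairs up to the next
// '&' or whitespace. The parameter names are an assumed list.
var sensitiveParam = regexp.MustCompile(`(?i)\b(token|api_key)=[^&\s]+`)

// redact replaces the value of every sensitive parameter with "***"
// while keeping the parameter name, so log lines stay readable.
func redact(s string) string {
	return sensitiveParam.ReplaceAllString(s, "${1}=***")
}

func main() {
	// Hypothetical WebSocket URL with the token passed as a query parameter.
	fmt.Println(redact("/api/admin/ops/ws?token=abc123&window=1"))
}
```

Redacting at the logging boundary keeps the change local: handlers can still read the real token from the request, and only the text written to the log is altered.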
## Files changed
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go
Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts
Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (new)
- backend/internal/service/ops_alert_service_integration_test.go (new)
Docs:
- CHANGELOG.md (new)
- docs/API-运维监控中心2.0.md (updated)
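* fix(migrations): fix a type mismatch in the calculate_health_score function
Add explicit type casts in the ops_latest_metrics view so that the argument types match the function signature
* fix(lint): fix all issues reported by golangci-lint
- Move the Redis dependency from the service layer to the repository layer
- Add error checks (WebSocket connection and read timeout)
- Run gofmt on the code
- Add nil pointer checks
- Remove the unused alertService field
Issues fixed:
- depguard: 3 (the service layer must not import redis directly)
- errcheck: 3 (unchecked error return values)
- gofmt: 2 (formatting issues)
- staticcheck: 4 (nil pointer dereferences)
- unused: 1 (unused field)
Code statistics:
- Files changed: 11
- Lines deleted: 490
- Lines added: 105
- Net reduction: 385 lines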
@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
 }
 
 func (a *Account) GetBaseURL() string {
-	if a.Type != AccountTypeApiKey {
+	if a.Type != AccountTypeAPIKey {
 		return ""
 	}
 	baseURL := a.GetCredential("base_url")
@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string {
 }
 
 func (a *Account) IsCustomErrorCodesEnabled() bool {
-	if a.Type != AccountTypeApiKey || a.Credentials == nil {
+	if a.Type != AccountTypeAPIKey || a.Credentials == nil {
 		return false
 	}
 	if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool {
 	return a.IsOpenAI() && a.Type == AccountTypeOAuth
 }
 
-func (a *Account) IsOpenAIApiKey() bool {
-	return a.IsOpenAI() && a.Type == AccountTypeApiKey
+func (a *Account) IsOpenAIAPIKey() bool {
+	return a.IsOpenAI() && a.Type == AccountTypeAPIKey
 }
 
 func (a *Account) GetOpenAIBaseURL() string {
 	if !a.IsOpenAI() {
 		return ""
 	}
-	if a.Type == AccountTypeApiKey {
+	if a.Type == AccountTypeAPIKey {
 		baseURL := a.GetCredential("base_url")
 		if baseURL != "" {
 			return baseURL
@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string {
 	return a.GetCredential("id_token")
 }
 
-func (a *Account) GetOpenAIApiKey() string {
-	if !a.IsOpenAIApiKey() {
+func (a *Account) GetOpenAIAPIKey() string {
+	if !a.IsOpenAIAPIKey() {
 		return ""
 	}
 	return a.GetCredential("api_key")
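@@ -1,3 +1,5 @@
+// Package service provides business-logic services that encapsulate the business rules and workflows of the domain model.
+// The service layer coordinates data access through the repository layer, implements cross-entity business logic, and exposes a unified business interface to the API layer above.
 package service
 
 import (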
@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
 		chatgptAccountID = account.GetChatGPTAccountID()
 	} else if account.Type == "apikey" {
 		// API Key - use Platform API
-		authToken = account.GetOpenAIApiKey()
+		authToken = account.GetOpenAIAPIKey()
 		if authToken == "" {
 			return s.sendErrorAndEnd(c, "No API key available")
 		}
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 	}
 
 	// For API Key accounts with model mapping, map the model
-	if account.Type == AccountTypeApiKey {
+	if account.Type == AccountTypeAPIKey {
 		mapping := account.GetModelMapping()
 		if len(mapping) > 0 {
 			if mappedModel, exists := mapping[testModelID]; exists {
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 	var err error
 
 	switch account.Type {
-	case AccountTypeApiKey:
+	case AccountTypeAPIKey:
 		req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
 	case AccountTypeOAuth:
 		req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
@@ -17,11 +17,11 @@ type UsageLogRepository interface {
 	Delete(ctx context.Context, id int64) error
 
 	ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
-	ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
+	ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
 	ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
 
 	ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
-	ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
+	ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
 	ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
 	ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
 
@@ -32,10 +32,10 @@ type UsageLogRepository interface {
 	GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
 	GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
 	GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
-	GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
+	GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
 	GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
 	GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
-	GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
+	GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
 
 	// User dashboard stats
 	GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +51,7 @@ type UsageLogRepository interface {
 
 	// Aggregated stats (optimized)
 	GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
-	GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
+	GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
 	GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
 	GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
 	GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -19,7 +19,7 @@ type AdminService interface {
 	UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
 	DeleteUser(ctx context.Context, id int64) error
 	UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
-	GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
+	GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
 	GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
 
 	// Group management
@@ -30,7 +30,7 @@ type AdminService interface {
 	CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
 	UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
 	DeleteGroup(ctx context.Context, id int64) error
-	GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
+	GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
 
 	// Account management
 	ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +65,7 @@ type AdminService interface {
 	ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
 }
 
-// Input types for admin operations
+// CreateUserInput represents the input for creating a new user
 type CreateUserInput struct {
 	Email string
 	Password string
@@ -220,7 +220,7 @@ type adminServiceImpl struct {
 	groupRepo GroupRepository
 	accountRepo AccountRepository
 	proxyRepo ProxyRepository
-	apiKeyRepo ApiKeyRepository
+	apiKeyRepo APIKeyRepository
 	redeemCodeRepo RedeemCodeRepository
 	billingCacheService *BillingCacheService
 	proxyProber ProxyExitInfoProber
@@ -232,7 +232,7 @@ func NewAdminService(
 	groupRepo GroupRepository,
 	accountRepo AccountRepository,
 	proxyRepo ProxyRepository,
-	apiKeyRepo ApiKeyRepository,
+	apiKeyRepo APIKeyRepository,
 	redeemCodeRepo RedeemCodeRepository,
 	billingCacheService *BillingCacheService,
 	proxyProber ProxyExitInfoProber,
@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
 	return user, nil
 }
 
-func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
+func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
 	params := pagination.PaginationParams{Page: page, PageSize: pageSize}
 	keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
 	if err != nil {
@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
 	return nil
 }
 
-func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
+func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
 	params := pagination.PaginationParams{Page: page, PageSize: pageSize}
 	keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
 	if err != nil {
@@ -2,7 +2,7 @@ package service
 
 import "time"
 
-type ApiKey struct {
+type APIKey struct {
 	ID int64
 	UserID int64
 	Key string
@@ -15,6 +15,6 @@ type ApiKey struct {
 	Group *Group
 }
 
-func (k *ApiKey) IsActive() bool {
+func (k *APIKey) IsActive() bool {
 	return k.Status == StatusActive
 }
@@ -14,39 +14,39 @@ import (
 )
 
 var (
-	ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
+	ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
 	ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
-	ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
-	ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
-	ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
-	ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
+	ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
+	ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
+	ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
+	ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
 )
 
 const (
 	apiKeyMaxErrorsPerHour = 20
 )
 
-type ApiKeyRepository interface {
-	Create(ctx context.Context, key *ApiKey) error
-	GetByID(ctx context.Context, id int64) (*ApiKey, error)
+type APIKeyRepository interface {
+	Create(ctx context.Context, key *APIKey) error
+	GetByID(ctx context.Context, id int64) (*APIKey, error)
 	// GetOwnerID fetches only the owner ID of an API key, for a lightweight permission check before deletion
 	GetOwnerID(ctx context.Context, id int64) (int64, error)
-	GetByKey(ctx context.Context, key string) (*ApiKey, error)
-	Update(ctx context.Context, key *ApiKey) error
+	GetByKey(ctx context.Context, key string) (*APIKey, error)
+	Update(ctx context.Context, key *APIKey) error
 	Delete(ctx context.Context, id int64) error
 
-	ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
+	ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
 	VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
 	CountByUserID(ctx context.Context, userID int64) (int64, error)
 	ExistsByKey(ctx context.Context, key string) (bool, error)
-	ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
-	SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
+	ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
+	SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
 	ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
 	CountByGroupID(ctx context.Context, groupID int64) (int64, error)
 }
 
-// ApiKeyCache defines cache operations for API key service
-type ApiKeyCache interface {
+// APIKeyCache defines cache operations for API key service
+type APIKeyCache interface {
 	GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
 	IncrementCreateAttemptCount(ctx context.Context, userID int64) error
 	DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
 	SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
 }
 
-// CreateApiKeyRequest is the request to create an API key
-type CreateApiKeyRequest struct {
+// CreateAPIKeyRequest is the request to create an API key
+type CreateAPIKeyRequest struct {
 	Name string `json:"name"`
 	GroupID *int64 `json:"group_id"`
 	CustomKey *string `json:"custom_key"` // optional custom key
 }
 
-// UpdateApiKeyRequest is the request to update an API key
-type UpdateApiKeyRequest struct {
+// UpdateAPIKeyRequest is the request to update an API key
+type UpdateAPIKeyRequest struct {
 	Name *string `json:"name"`
 	GroupID *int64 `json:"group_id"`
 	Status *string `json:"status"`
 }
 
-// ApiKeyService is the API key service
-type ApiKeyService struct {
-	apiKeyRepo ApiKeyRepository
+// APIKeyService is the API key service
+type APIKeyService struct {
+	apiKeyRepo APIKeyRepository
 	userRepo UserRepository
 	groupRepo GroupRepository
 	userSubRepo UserSubscriptionRepository
-	cache ApiKeyCache
+	cache APIKeyCache
 	cfg *config.Config
 }
 
-// NewApiKeyService creates an API key service instance
-func NewApiKeyService(
-	apiKeyRepo ApiKeyRepository,
+// NewAPIKeyService creates an API key service instance
+func NewAPIKeyService(
+	apiKeyRepo APIKeyRepository,
 	userRepo UserRepository,
 	groupRepo GroupRepository,
 	userSubRepo UserSubscriptionRepository,
-	cache ApiKeyCache,
+	cache APIKeyCache,
 	cfg *config.Config,
-) *ApiKeyService {
-	return &ApiKeyService{
+) *APIKeyService {
+	return &APIKeyService{
 		apiKeyRepo: apiKeyRepo,
 		userRepo: userRepo,
 		groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewApiKeyService(
 }
 
 // GenerateKey generates a random API key
-func (s *ApiKeyService) GenerateKey() (string, error) {
+func (s *APIKeyService) GenerateKey() (string, error) {
 	// Generate 32 bytes of random data
 	bytes := make([]byte, 32)
 	if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
 	}
 
 	// Convert to a hex string and add the prefix
-	prefix := s.cfg.Default.ApiKeyPrefix
+	prefix := s.cfg.Default.APIKeyPrefix
 	if prefix == "" {
 		prefix = "sk-"
 	}
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
 }
 
 // ValidateCustomKey validates the format of a custom API key
-func (s *ApiKeyService) ValidateCustomKey(key string) error {
+func (s *APIKeyService) ValidateCustomKey(key string) error {
 	// Check the length
 	if len(key) < 16 {
-		return ErrApiKeyTooShort
+		return ErrAPIKeyTooShort
 	}
 
 	// Check characters: only letters, digits, underscores, and hyphens are allowed
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
 			c == '_' || c == '-' {
 			continue
 		}
-		return ErrApiKeyInvalidChars
+		return ErrAPIKeyInvalidChars
 	}
 
 	return nil
 }
 
-// checkApiKeyRateLimit checks whether the user's failed custom-key creation attempts exceed the limit
-func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
+// checkAPIKeyRateLimit checks whether the user's failed custom-key creation attempts exceed the limit
+func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
 	if s.cache == nil {
 		return nil
 	}
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
 	}
 
 	if count >= apiKeyMaxErrorsPerHour {
-		return ErrApiKeyRateLimited
+		return ErrAPIKeyRateLimited
 	}
 
 	return nil
 }
 
-// incrementApiKeyErrorCount increments the user's error count for custom-key creation
-func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
+// incrementAPIKeyErrorCount increments the user's error count for custom-key creation
+func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
 	if s.cache == nil {
 		return
 	}
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
 // canUserBindGroup checks whether the user may bind the given group
 // Subscription-type groups: check that the user has a valid subscription
 // Standard-type groups: use the existing AllowedGroups and IsExclusive logic
-func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
+func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
 	// Subscription-type group: a valid subscription is required
 	if group.IsSubscriptionType() {
 		_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
 }
 
 // Create creates an API key
-func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
+func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
 	// Verify that the user exists
 	user, err := s.userRepo.GetByID(ctx, userID)
 	if err != nil {
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
 	// Determine whether a custom key is used
 	if req.CustomKey != nil && *req.CustomKey != "" {
 		// Check the rate limit (custom keys only)
-		if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
+		if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
 			return nil, err
 		}
 
@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
 			return nil, fmt.Errorf("check key exists: %w", err)
 		}
 		if exists {
-			// Key already exists; increment the error count
-			s.incrementApiKeyErrorCount(ctx, userID)
-			return nil, ErrApiKeyExists
+			// Key already exists; increment the error count
+			s.incrementAPIKeyErrorCount(ctx, userID)
+			return nil, ErrAPIKeyExists
 		}
 
 		key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
 	}
 
 	// Create the API key record
-	apiKey := &ApiKey{
+	apiKey := &APIKey{
 		UserID: userID,
 		Key: key,
 		Name: req.Name,
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
 }
 
 // List returns the user's API keys
-func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
+func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
 	keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
 	if err != nil {
 		return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
 	return keys, pagination, nil
 }
 
-func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
+func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
 	if len(apiKeyIDs) == 0 {
 		return []int64{}, nil
 	}
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
 }
 
 // GetByID returns an API key by ID
-func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
+func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
 	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
 	if err != nil {
 		return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
 }
 
 // GetByKey returns an API key by its key string (used for authentication)
-func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
+func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
 	// Try to read from the Redis cache
 	cacheKey := fmt.Sprintf("apikey:%s", key)
 
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
 }
 
 // Update updates an API key
-func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
+func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
 	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
 	if err != nil {
 		return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
 
 // Delete deletes an API key
 // Optimization: use GetOwnerID instead of GetByID for the permission check,
-// which avoids loading the full ApiKey object and its associations (User, Group) and speeds up deletion
-func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
+// which avoids loading the full APIKey object and its associations (User, Group) and speeds up deletion
+func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
 	// Fetch only the owner ID for the permission check instead of loading the full object
 	ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
 	if err != nil {
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
 }
 
 // ValidateKey checks whether an API key is valid (used by the auth middleware)
-func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
+func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
 	// Fetch the API key
 	apiKey, err := s.GetByKey(ctx, key)
 	if err != nil {
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
 }
 
 // IncrementUsage increments the API key usage count (optional, for statistics)
-func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
+func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
 	// Use a Redis counter
 	if s.cache != nil {
 		cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
 // Returns the groups the user can choose from:
 // - Standard-type groups: public (non-exclusive) or explicitly allowed for the user
 // - Subscription-type groups: those for which the user has a valid subscription
-func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
+func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
 	// Fetch the user
 	user, err := s.userRepo.GetByID(ctx, userID)
 	if err != nil {
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
 }
 
 // canUserBindGroupInternal is an internal helper that checks whether the user may bind a group (using preloaded subscription data)
-func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
+func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
 	// Subscription-type group: a valid subscription is required
 	if group.IsSubscriptionType() {
 		return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
 	return user.CanBindGroup(group.ID, group.IsExclusive)
 }
 
-func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
-	keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
+func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
+	keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
 	if err != nil {
 		return nil, fmt.Errorf("search api keys: %w", err)
 	}
@@ -1,7 +1,7 @@
 //go:build unit
 
 // Unit tests for the API key service's delete method
-// Tests the behavior of ApiKeyService.Delete in various scenarios,
+// Tests the behavior of APIKeyService.Delete in various scenarios,
 // including permission checks, cache cleanup, and error handling
 
 package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)

// apiKeyRepoStub is a test stub implementation of the ApiKeyRepository interface.
// It isolates tests of ApiKeyService.Delete from a real database.
// apiKeyRepoStub is a test stub implementation of the APIKeyRepository interface.
// It isolates tests of APIKeyService.Delete from a real database.
//
// Design notes:
// - ownerID: the owner ID returned by the simulated GetOwnerID
// - ownerErr: the error returned by the simulated GetOwnerID (e.g. ErrApiKeyNotFound)
// - ownerErr: the error returned by the simulated GetOwnerID (e.g. ErrAPIKeyNotFound)
// - deleteErr: the error returned by the simulated Delete
// - deletedIDs: records the API Key IDs passed to Delete, for use in assertions
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {

// The methods below should not be called in these tests; they panic so that a failing test pinpoints the problem quickly

func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}

func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}

@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}

func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}

func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}

@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {

// The methods below are required by the interface but are not relevant to these tests

func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}

@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}

func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}

func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}

func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}

// apiKeyCacheStub is a test stub implementation of the ApiKeyCache interface.
// apiKeyCacheStub is a test stub implementation of the APIKeyCache interface.
// It is used to verify that cache invalidation is triggered correctly during a delete.
//
// Design notes:
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}

// TestApiKeyService_Delete_OwnerMismatch verifies that a permission error is returned when a non-owner attempts the delete.
// TestAPIKeyService_Delete_OwnerMismatch verifies that a permission error is returned when a non-owner attempts the delete.
// Expected behavior:
// - GetOwnerID returns owner ID 1
// - The caller's userID is 2 (mismatch)
// - ErrInsufficientPerms is returned
// - Delete is not called
// - The cache is not invalidated
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}

err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, caller userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // verify the cache was not invalidated
}

// TestApiKeyService_Delete_Success covers the case where the owner successfully deletes an API Key.
// TestAPIKeyService_Delete_Success covers the case where the owner successfully deletes an API Key.
// Expected behavior:
// - GetOwnerID returns owner ID 7
// - The caller's userID is 7 (match)
// - Delete executes successfully
// - The cache is invalidated correctly (using ownerID)
// - A nil error is returned
func TestApiKeyService_Delete_Success(t *testing.T) {
func TestAPIKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}

err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, caller userID=7
require.NoError(t, err)
@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // verify the owner's cache was invalidated
}

// TestApiKeyService_Delete_NotFound verifies that deleting a nonexistent API Key returns the correct error.
// TestAPIKeyService_Delete_NotFound verifies that deleting a nonexistent API Key returns the correct error.
// Expected behavior:
// - GetOwnerID returns ErrApiKeyNotFound
// - ErrApiKeyNotFound is returned (wrapped by fmt.Errorf)
// - GetOwnerID returns ErrAPIKeyNotFound
// - ErrAPIKeyNotFound is returned (wrapped by fmt.Errorf)
// - Delete is not called
// - The cache is not invalidated
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}

err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}

// TestApiKeyService_Delete_DeleteFails covers error handling when the delete operation fails.
// TestAPIKeyService_Delete_DeleteFails covers error handling when the delete operation fails.
// Expected behavior:
// - GetOwnerID returns the correct owner ID
// - The ownership check passes
// - The cache is invalidated (before the delete)
// - Delete is called but returns an error
// - The returned error message contains "delete api key"
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}

err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, caller userID=3
require.Error(t, err)
|
||||
@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility checks whether the user is eligible to make a request
// Balance mode: checks that the cached balance is > 0
// Subscription mode: checks that cached usage does not exceed the limit (the Group limit is passed in as a parameter)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// Simple mode: skip all billing checks
if s.cfg.RunMode == config.RunModeSimple {
return nil

@@ -32,6 +32,7 @@ type ConcurrencyCache interface {
// Wait queue counters (the TTL is set only on first creation)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)

// Batch load queries (read-only)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}

// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}

// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
|
||||
@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
} `json:"data"`
}

@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}

func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}

func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeApiKey = "apikey" // API Key account
AccountTypeAPIKey = "apikey" // API Key account
)

// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // whether email verification is enabled

// Email service settings
SettingKeySmtpHost = "smtp_host" // SMTP server address
SettingKeySmtpPort = "smtp_port" // SMTP port
SettingKeySmtpUsername = "smtp_username" // SMTP username
SettingKeySmtpPassword = "smtp_password" // SMTP password (stored encrypted)
SettingKeySmtpFrom = "smtp_from" // sender address
SettingKeySmtpFromName = "smtp_from_name" // sender name
SettingKeySmtpUseTLS = "smtp_use_tls" // whether to use TLS
SettingKeySMTPHost = "smtp_host" // SMTP server address
SettingKeySMTPPort = "smtp_port" // SMTP port
SettingKeySMTPUsername = "smtp_username" // SMTP username
SettingKeySMTPPassword = "smtp_password" // SMTP password (stored encrypted)
SettingKeySMTPFrom = "smtp_from" // sender address
SettingKeySMTPFromName = "smtp_from_name" // sender name
SettingKeySMTPUseTLS = "smtp_use_tls" // whether to use TLS

// Cloudflare Turnstile settings
SettingKeyTurnstileEnabled = "turnstile_enabled" // whether Turnstile verification is enabled
@@ -81,20 +81,20 @@ const (
SettingKeySiteName = "site_name" // site name
SettingKeySiteLogo = "site_logo" // site logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // site subtitle
SettingKeyApiBaseUrl = "api_base_url" // API endpoint address (used for client configuration and import)
SettingKeyAPIBaseURL = "api_base_url" // API endpoint address (used for client configuration and import)
SettingKeyContactInfo = "contact_info" // support contact info
SettingKeyDocUrl = "doc_url" // documentation link
SettingKeyDocURL = "doc_url" // documentation link

// Default settings
SettingKeyDefaultConcurrency = "default_concurrency" // default concurrency for new users
SettingKeyDefaultBalance = "default_balance" // default balance for new users

// Admin API Key
SettingKeyAdminApiKey = "admin_api_key" // global admin API Key (for external system integration)
SettingKeyAdminAPIKey = "admin_api_key" // global admin API Key (for external system integration)

// Gemini quota policy (JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)

// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-"

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)

// SmtpConfig holds the SMTP configuration
type SmtpConfig struct {
// SMTPConfig holds the SMTP configuration
type SMTPConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}

// GetSmtpConfig loads the SMTP configuration from the database
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
// GetSMTPConfig loads the SMTP configuration from the database
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
}

settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}

host := settings[SettingKeySmtpHost]
host := settings[SettingKeySMTPHost]
if host == "" {
return nil, ErrEmailNotConfigured
}

port := 587 // default port
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}

useTLS := settings[SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySMTPUseTLS] == "true"

return &SmtpConfig{
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS,
}, nil
}

// SendEmail sends an email (using the configuration stored in the database)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
config, err := s.GetSMTPConfig(ctx)
if err != nil {
return err
}
|
||||
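Review note: the port handling in `GetSMTPConfig` falls back silently: an empty or non-numeric `smtp_port` setting keeps the default of 587 rather than producing an error. The sketch below isolates that rule so the behavior is easy to see. The helper name `parseSMTPPort` and the constant `defaultSMTPPort` are hypothetical names introduced for this example; in the diff the logic is written inline.

```go
package main

import (
	"fmt"
	"strconv"
)

// defaultSMTPPort matches the literal 587 used in the diff.
const defaultSMTPPort = 587

// parseSMTPPort returns the configured port, or the default when the
// setting is empty or is not a valid integer.
func parseSMTPPort(raw string) int {
	port := defaultSMTPPort
	if raw != "" {
		if p, err := strconv.Atoi(raw); err == nil {
			port = p
		}
	}
	return port
}

func main() {
	for _, raw := range []string{"", "465", "not-a-number"} {
		fmt.Printf("%q -> %d\n", raw, parseSMTPPort(raw))
	}
}
```

As in the hunk above, there is no range check here, so a stored value such as `0` or `70000` would be passed through unchanged. A stricter variant could reject values outside 1 to 65535, at the cost of surfacing a configuration error to the caller.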
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}

// SendEmailWithConfig sends an email using the given configuration
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}

// TestSmtpConnectionWithConfig tests the SMTP connection using the given configuration
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
// TestSMTPConnectionWithConfig tests the SMTP connection using the given configuration
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)

if config.UseTLS {

@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(

repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A

// Apply model mapping (apikey accounts only)
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// Replace the model name in the request body
@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Handle the anthropic-beta header (OAuth accounts need special handling)
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key: fill in the header on demand, only when the request explicitly uses beta features and the client did not supply it (off by default)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}

func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}

func truncateForLog(b []byte, maxBytes int) string {
|
||||
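Review note: `defaultAPIKeyBetaHeader` chooses a header by a case-insensitive substring match on the request's `model` field. The sketch below reproduces that selection using only the standard library. It swaps `gjson` for `encoding/json`, and the two header values are placeholders: the real `claude.APIKeyBetaHeader` and `claude.APIKeyHaikuBetaHeader` constants are not shown in this diff.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Placeholder values, assumed for illustration only.
const (
	apiKeyBetaHeader      = "example-default-beta"
	apiKeyHaikuBetaHeader = "example-haiku-beta"
)

// defaultBetaHeader picks the Haiku header when the model name contains
// "haiku" (in any letter case) and the default header otherwise.
func defaultBetaHeader(body []byte) string {
	var req struct {
		Model string `json:"model"`
	}
	// On a decode error Model stays empty, which selects the default header.
	_ = json.Unmarshal(body, &req)
	if strings.Contains(strings.ToLower(req.Model), "haiku") {
		return apiKeyHaikuBetaHeader
	}
	return apiKeyBetaHeader
}

func main() {
	fmt.Println(defaultBetaHeader([]byte(`{"model":"claude-3-5-Haiku-latest"}`)))
	fmt.Println(defaultBetaHeader([]byte(`{"model":"claude-sonnet-4"}`)))
	fmt.Println(defaultBetaHeader([]byte(`not json`)))
}
```

Substring matching keeps the rule independent of exact model version strings, but it also means any future model whose name happens to contain "haiku" will receive the Haiku header.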
@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput holds the input parameters for recording usage
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // optional: subscription info
@@ -1639,7 +1639,7 @@ type RecordUsageInput struct {
// RecordUsage records usage and deducts the charge (or updates subscription usage)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}

// Apply model mapping (apikey accounts only)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth accounts: handle the anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key: on-demand beta injection, kept in sync with messages (off by default)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex

originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}

@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}

switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}

mappedModel := originalModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}

@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)

switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}

switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")

@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr

repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown
}

// fetchGoogleOneTier fetches Google One tier from Drive API
// FetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient()

@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
case AccountTypeAPIKey:
apiKey := account.GetOpenAIAPIKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
	durationMs := int(result.Duration.Milliseconds())
	usageLog := &UsageLog{
		UserID:    user.ID,
-		ApiKeyID:  apiKey.ID,
+		APIKeyID:  apiKey.ID,
		AccountID: account.ID,
		RequestID: result.RequestID,
		Model:     result.Model,

99
backend/internal/service/ops.go
Normal file
@@ -0,0 +1,99 @@
package service

import (
	"context"
	"time"
)

// ErrorLog represents an ops error log item for list queries.
//
// Field naming matches docs/API-运维监控中心2.0.md (L3 root-cause tracing - error log list).
type ErrorLog struct {
	ID        int64     `json:"id"`
	Timestamp time.Time `json:"timestamp"`

	Level        string `json:"level,omitempty"`
	RequestID    string `json:"request_id,omitempty"`
	AccountID    string `json:"account_id,omitempty"`
	APIPath      string `json:"api_path,omitempty"`
	Provider     string `json:"provider,omitempty"`
	Model        string `json:"model,omitempty"`
	HTTPCode     int    `json:"http_code,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`

	DurationMs *int `json:"duration_ms,omitempty"`
	RetryCount *int `json:"retry_count,omitempty"`
	Stream     bool `json:"stream,omitempty"`
}

// ErrorLogFilter describes optional filters and pagination for listing ops error logs.
type ErrorLogFilter struct {
	StartTime *time.Time
	EndTime   *time.Time

	ErrorCode *int
	Provider  string
	AccountID *int64

	Page     int
	PageSize int
}

// normalize returns the effective page and page size. Page defaults to 1;
// pageSize defaults to 20 and is capped at 100. It is safe to call on a nil receiver.
func (f *ErrorLogFilter) normalize() (page, pageSize int) {
	page = 1
	pageSize = 20
	if f == nil {
		return page, pageSize
	}

	if f.Page > 0 {
		page = f.Page
	}
	if f.PageSize > 0 {
		pageSize = f.PageSize
	}
	if pageSize > 100 {
		pageSize = 100
	}
	return page, pageSize
}

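The clamping rules in normalize can be tried in isolation. The sketch below is a stand-alone illustration, not code from this PR: `normalizePage` and `pageOffset` are hypothetical names, and the offset formula is an assumption about how a repository might consume the normalized values.

```go
package main

import "fmt"

// normalizePage is a hypothetical stand-alone version of the clamping rules:
// page defaults to 1, pageSize defaults to 20 and is capped at 100.
func normalizePage(page, pageSize int) (int, int) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 20
	}
	if pageSize > 100 {
		pageSize = 100
	}
	return page, pageSize
}

// pageOffset converts a 1-based page into a row offset. This is an assumption
// about how a repository could use the values; the diff does not show it.
func pageOffset(page, pageSize int) int {
	return (page - 1) * pageSize
}

func main() {
	page, size := normalizePage(0, 500)
	fmt.Println(page, size, pageOffset(page, size)) // 1 100 0

	page, size = normalizePage(3, 0)
	fmt.Println(page, size, pageOffset(page, size)) // 3 20 40
}
```

Capping the page size on the server side keeps a single request from pulling an unbounded number of rows, whatever the client sends.
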
// ErrorLogListResponse is the paginated result returned by GetErrorLogs.
type ErrorLogListResponse struct {
	Errors   []*ErrorLog `json:"errors"`
	Total    int64       `json:"total"`
	Page     int         `json:"page"`
	PageSize int         `json:"page_size"`
}

// GetErrorLogs returns one page of error logs matching filter. A nil service
// or repository yields an empty first page rather than an error.
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *ErrorLogFilter) (*ErrorLogListResponse, error) {
	if s == nil || s.repo == nil {
		return &ErrorLogListResponse{
			Errors:   []*ErrorLog{},
			Total:    0,
			Page:     1,
			PageSize: 20,
		}, nil
	}

	page, pageSize := filter.normalize()
	if filter == nil {
		filter = &ErrorLogFilter{}
	}
	filter.Page = page
	filter.PageSize = pageSize

	items, total, err := s.repo.ListErrorLogs(ctx, filter)
	if err != nil {
		return nil, err
	}
	if items == nil {
		items = []*ErrorLog{}
	}

	return &ErrorLogListResponse{
		Errors:   items,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
834
backend/internal/service/ops_alert_service.go
Normal file
@@ -0,0 +1,834 @@
package service

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"
)

type OpsAlertService struct {
	opsService   *OpsService
	userService  *UserService
	emailService *EmailService
	httpClient   *http.Client

	interval time.Duration

	startOnce sync.Once
	stopOnce  sync.Once
	stopCtx   context.Context
	stop      context.CancelFunc
	wg        sync.WaitGroup
}

// opsAlertEvalInterval defines how often OpsAlertService evaluates alert rules.
//
// Production uses opsMetricsInterval. Tests may override this variable to keep
// integration tests fast without changing production defaults.
var opsAlertEvalInterval = opsMetricsInterval

func NewOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
	return &OpsAlertService{
		opsService:   opsService,
		userService:  userService,
		emailService: emailService,
		httpClient:   &http.Client{Timeout: 10 * time.Second},
		interval:     opsAlertEvalInterval,
	}
}

// Start launches the background alert evaluation loop.
//
// Stop must be called during shutdown to ensure the goroutine exits.
func (s *OpsAlertService) Start() {
	s.StartWithContext(context.Background())
}

// StartWithContext is like Start but allows the caller to provide a parent context.
// When the parent context is canceled, the service stops automatically.
func (s *OpsAlertService) StartWithContext(ctx context.Context) {
	if s == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}

	s.startOnce.Do(func() {
		if s.interval <= 0 {
			s.interval = opsAlertEvalInterval
		}

		s.stopCtx, s.stop = context.WithCancel(ctx)
		s.wg.Add(1)
		go s.run()
	})
}

// Stop gracefully stops the background goroutine started by Start/StartWithContext.
// It is safe to call Stop multiple times.
func (s *OpsAlertService) Stop() {
	if s == nil {
		return
	}

	s.stopOnce.Do(func() {
		if s.stop != nil {
			s.stop()
		}
	})
	s.wg.Wait()
}

func (s *OpsAlertService) run() {
	defer s.wg.Done()

	ticker := time.NewTicker(s.interval)
	defer ticker.Stop()

	s.evaluateOnce()
	for {
		select {
		case <-ticker.C:
			s.evaluateOnce()
		case <-s.stopCtx.Done():
			return
		}
	}
}

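The Start/Stop/run trio above follows a common pattern: `sync.Once` guards both start and stop, a cancelable context ends the loop, and a `sync.WaitGroup` lets Stop block until the goroutine has exited. A minimal stand-alone sketch of that pattern follows; the `worker` type and its methods are hypothetical names used only for illustration.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// worker is a hypothetical stand-in for a service with a background loop.
type worker struct {
	interval  time.Duration
	ticks     atomic.Int64
	startOnce sync.Once
	stopOnce  sync.Once
	ctx       context.Context
	cancel    context.CancelFunc
	wg        sync.WaitGroup
}

func newWorker(intervalMs int) *worker {
	return &worker{interval: time.Duration(intervalMs) * time.Millisecond}
}

// Start launches the loop at most once, however often it is called.
func (w *worker) Start() {
	w.startOnce.Do(func() {
		w.ctx, w.cancel = context.WithCancel(context.Background())
		w.wg.Add(1)
		go w.run()
	})
}

// Stop cancels the loop and waits for it to exit; repeated calls are no-ops.
func (w *worker) Stop() {
	w.stopOnce.Do(func() {
		if w.cancel != nil {
			w.cancel()
		}
	})
	w.wg.Wait()
}

func (w *worker) run() {
	defer w.wg.Done()
	ticker := time.NewTicker(w.interval)
	defer ticker.Stop()

	w.ticks.Add(1) // one unit of work immediately, then one per tick
	for {
		select {
		case <-ticker.C:
			w.ticks.Add(1)
		case <-w.ctx.Done():
			return
		}
	}
}

// Ticks reports how many units of work have run so far.
func (w *worker) Ticks() int64 { return w.ticks.Load() }

func main() {
	w := newWorker(3600000) // one hour, so only the immediate run happens
	w.Start()
	w.Start() // ignored: the loop is already running
	w.Stop()
	w.Stop() // safe to repeat
	fmt.Println(w.Ticks())
}
```

One property of this shape worth noting: calling Stop before Start consumes the stop guard, so a later Start would launch a loop that Stop can no longer cancel. Callers should keep the Start-then-Stop order.
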
func (s *OpsAlertService) evaluateOnce() {
	ctx, cancel := context.WithTimeout(s.stopCtx, opsAlertEvaluateTimeout)
	defer cancel()

	s.Evaluate(ctx, time.Now())
}

func (s *OpsAlertService) Evaluate(ctx context.Context, now time.Time) {
	if s == nil || s.opsService == nil {
		return
	}

	rules, err := s.opsService.ListAlertRules(ctx)
	if err != nil {
		log.Printf("[OpsAlert] failed to list rules: %v", err)
		return
	}
	if len(rules) == 0 {
		return
	}

	maxSustainedByWindow := make(map[int]int)
	for _, rule := range rules {
		if !rule.Enabled {
			continue
		}
		window := rule.WindowMinutes
		if window <= 0 {
			window = 1
		}
		sustained := rule.SustainedMinutes
		if sustained <= 0 {
			sustained = 1
		}
		if sustained > maxSustainedByWindow[window] {
			maxSustainedByWindow[window] = sustained
		}
	}

	metricsByWindow := make(map[int][]OpsMetrics)
	for window, limit := range maxSustainedByWindow {
		metrics, err := s.opsService.ListRecentSystemMetrics(ctx, window, limit)
		if err != nil {
			log.Printf("[OpsAlert] failed to load metrics window=%dm: %v", window, err)
			continue
		}
		metricsByWindow[window] = metrics
	}

	for _, rule := range rules {
		if !rule.Enabled {
			continue
		}
		window := rule.WindowMinutes
		if window <= 0 {
			window = 1
		}
		sustained := rule.SustainedMinutes
		if sustained <= 0 {
			sustained = 1
		}

		metrics := metricsByWindow[window]
		selected, ok := selectContiguousMetrics(metrics, sustained, now)
		if !ok {
			continue
		}

		breached, latestValue, ok := evaluateRule(rule, selected)
		if !ok {
			continue
		}

		activeEvent, err := s.opsService.GetActiveAlertEvent(ctx, rule.ID)
		if err != nil {
			log.Printf("[OpsAlert] failed to get active event (rule=%d): %v", rule.ID, err)
			continue
		}

		if breached {
			if activeEvent != nil {
				continue
			}

			lastEvent, err := s.opsService.GetLatestAlertEvent(ctx, rule.ID)
			if err != nil {
				log.Printf("[OpsAlert] failed to get latest event (rule=%d): %v", rule.ID, err)
				continue
			}
			if lastEvent != nil && rule.CooldownMinutes > 0 {
				cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
				if now.Sub(lastEvent.FiredAt) < cooldown {
					continue
				}
			}

			event := &OpsAlertEvent{
				RuleID:         rule.ID,
				Severity:       rule.Severity,
				Status:         OpsAlertStatusFiring,
				Title:          fmt.Sprintf("%s: %s", rule.Severity, rule.Name),
				Description:    buildAlertDescription(rule, latestValue),
				MetricValue:    latestValue,
				ThresholdValue: rule.Threshold,
				FiredAt:        now,
				CreatedAt:      now,
			}

			if err := s.opsService.CreateAlertEvent(ctx, event); err != nil {
				log.Printf("[OpsAlert] failed to create event (rule=%d): %v", rule.ID, err)
				continue
			}

			emailSent, webhookSent := s.dispatchNotifications(ctx, rule, event)
			if emailSent || webhookSent {
				if err := s.opsService.UpdateAlertEventNotifications(ctx, event.ID, emailSent, webhookSent); err != nil {
					log.Printf("[OpsAlert] failed to update notification flags (event=%d): %v", event.ID, err)
				}
			}
		} else if activeEvent != nil {
			resolvedAt := now
			if err := s.opsService.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
				log.Printf("[OpsAlert] failed to resolve event (event=%d): %v", activeEvent.ID, err)
			}
		}
	}
}

const opsMetricsContinuityTolerance = 20 * time.Second

// selectContiguousMetrics picks the newest N metrics and verifies they are continuous.
//
// This prevents a sustained rule from triggering when metrics sampling has gaps
// (e.g. collector downtime) and avoids evaluating "stale" data.
//
// Assumptions:
//   - Metrics are ordered by UpdatedAt DESC (newest first).
//   - Metrics are expected to be collected at opsMetricsInterval cadence.
func selectContiguousMetrics(metrics []OpsMetrics, needed int, now time.Time) ([]OpsMetrics, bool) {
	if needed <= 0 {
		return nil, false
	}
	if len(metrics) < needed {
		return nil, false
	}
	newest := metrics[0].UpdatedAt
	if newest.IsZero() {
		return nil, false
	}
	if now.Sub(newest) > opsMetricsInterval+opsMetricsContinuityTolerance {
		return nil, false
	}

	selected := metrics[:needed]
	for i := 0; i < len(selected)-1; i++ {
		a := selected[i].UpdatedAt
		b := selected[i+1].UpdatedAt
		if a.IsZero() || b.IsZero() {
			return nil, false
		}
		gap := a.Sub(b)
		if gap < opsMetricsInterval-opsMetricsContinuityTolerance || gap > opsMetricsInterval+opsMetricsContinuityTolerance {
			return nil, false
		}
	}
	return selected, true
}

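The continuity check above comes down to two tests: the newest sample must be fresh, and each adjacent pair must be spaced within interval ± tolerance. The sketch below reproduces that logic on bare timestamps. It assumes a 60-second collection interval (the value of opsMetricsInterval is not shown in this section), and the helper names `contiguous`, `secondsAgo` and `check` are hypothetical.

```go
package main

import (
	"fmt"
	"time"
)

// contiguous reports whether the first `needed` timestamps (newest first) are
// fresh relative to now and spaced within interval ± tolerance of each other.
func contiguous(stamps []time.Time, needed int, now time.Time, interval, tolerance time.Duration) bool {
	if needed <= 0 || len(stamps) < needed {
		return false
	}
	if now.Sub(stamps[0]) > interval+tolerance {
		return false // newest sample is stale
	}
	for i := 0; i < needed-1; i++ {
		gap := stamps[i].Sub(stamps[i+1])
		if gap < interval-tolerance || gap > interval+tolerance {
			return false // a sample is missing or two samples are too close
		}
	}
	return true
}

// secondsAgo builds newest-first timestamps from offsets in seconds before now.
func secondsAgo(now time.Time, offsets ...int) []time.Time {
	out := make([]time.Time, 0, len(offsets))
	for _, s := range offsets {
		out = append(out, now.Add(-time.Duration(s)*time.Second))
	}
	return out
}

// check wraps contiguous with a 60s interval (assumed) and a 20s tolerance.
func check(needed int, offsets ...int) bool {
	now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	return contiguous(secondsAgo(now, offsets...), needed, now, 60*time.Second, 20*time.Second)
}

func main() {
	fmt.Println(check(3, 5, 65, 125)) // true: evenly spaced, newest is 5s old
	fmt.Println(check(3, 5, 65, 245)) // false: 180s gap between the last two
	fmt.Println(check(2, 120, 180))   // false: newest sample is 120s old
}
```

Rejecting the whole window on any gap means a collector outage suppresses sustained alerts instead of letting old samples stand in for missing ones.
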
// evaluateRule reports whether every metric in the window breaches the rule.
// It returns (breached, latestValue, ok); ok is false when there are no metrics
// or the newest sample has no value for the rule's metric type.
func evaluateRule(rule OpsAlertRule, metrics []OpsMetrics) (bool, float64, bool) {
	if len(metrics) == 0 {
		return false, 0, false
	}

	latestValue, ok := metricValue(metrics[0], rule.MetricType)
	if !ok {
		return false, 0, false
	}

	for _, metric := range metrics {
		value, ok := metricValue(metric, rule.MetricType)
		if !ok || !compareMetric(value, rule.Operator, rule.Threshold) {
			return false, latestValue, true
		}
	}

	return true, latestValue, true
}

// metricValue extracts the value for metricType from a sample. Rate metrics
// report no value when the sample saw zero requests.
func metricValue(metric OpsMetrics, metricType string) (float64, bool) {
	switch metricType {
	case OpsMetricSuccessRate:
		if metric.RequestCount == 0 {
			return 0, false
		}
		return metric.SuccessRate, true
	case OpsMetricErrorRate:
		if metric.RequestCount == 0 {
			return 0, false
		}
		return metric.ErrorRate, true
	case OpsMetricP95LatencyMs:
		return float64(metric.P95LatencyMs), true
	case OpsMetricP99LatencyMs:
		return float64(metric.P99LatencyMs), true
	case OpsMetricHTTP2Errors:
		return float64(metric.HTTP2Errors), true
	case OpsMetricCPUUsagePercent:
		return metric.CPUUsagePercent, true
	case OpsMetricMemoryUsagePercent:
		return metric.MemoryUsagePercent, true
	case OpsMetricQueueDepth:
		return float64(metric.ConcurrencyQueueDepth), true
	default:
		return 0, false
	}
}

func compareMetric(value float64, operator string, threshold float64) bool {
	switch operator {
	case ">":
		return value > threshold
	case ">=":
		return value >= threshold
	case "<":
		return value < threshold
	case "<=":
		return value <= threshold
	case "==":
		return value == threshold
	default:
		return false
	}
}

func buildAlertDescription(rule OpsAlertRule, value float64) string {
	window := rule.WindowMinutes
	if window <= 0 {
		window = 1
	}
	return fmt.Sprintf("Rule %s triggered: %s %s %.2f (current %.2f) over last %dm",
		rule.Name,
		rule.MetricType,
		rule.Operator,
		rule.Threshold,
		value,
		window,
	)
}

func (s *OpsAlertService) dispatchNotifications(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) (bool, bool) {
	emailSent := false
	webhookSent := false

	notifyCtx, cancel := s.notificationContext(ctx)
	defer cancel()

	if rule.NotifyEmail {
		emailSent = s.sendEmailNotification(notifyCtx, rule, event)
	}
	if rule.NotifyWebhook && rule.WebhookURL != "" {
		webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
	}
	// Fallback channel: if email is enabled but ultimately fails, try webhook even if the
	// webhook toggle is off (as long as a webhook URL is configured).
	if rule.NotifyEmail && !emailSent && !rule.NotifyWebhook && rule.WebhookURL != "" {
		log.Printf("[OpsAlert] email failed; attempting webhook fallback (rule=%d)", rule.ID)
		webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
	}

	return emailSent, webhookSent
}

const (
	opsAlertEvaluateTimeout     = 45 * time.Second
	opsAlertNotificationTimeout = 30 * time.Second
	opsAlertEmailMaxRetries     = 3
)

var opsAlertEmailBackoff = []time.Duration{
	1 * time.Second,
	2 * time.Second,
	4 * time.Second,
}

func (s *OpsAlertService) notificationContext(ctx context.Context) (context.Context, context.CancelFunc) {
	parent := ctx
	if s != nil && s.stopCtx != nil {
		parent = s.stopCtx
	}
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, opsAlertNotificationTimeout)
}

var opsAlertSleep = sleepWithContext

func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}
	if ctx == nil {
		time.Sleep(d)
		return nil
	}
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// retryWithBackoff calls fn up to maxRetries+1 times. Before retry i it sleeps
// backoff[i-1]; once the backoff slice is exhausted it retries without sleeping.
// It returns nil on the first success, the context error if ctx is done, or the
// last error from fn. onError, if set, is called after every failed attempt.
func retryWithBackoff(
	ctx context.Context,
	maxRetries int,
	backoff []time.Duration,
	fn func() error,
	onError func(attempt int, total int, nextDelay time.Duration, err error),
) error {
	if ctx == nil {
		ctx = context.Background()
	}
	if maxRetries < 0 {
		maxRetries = 0
	}
	totalAttempts := maxRetries + 1

	var lastErr error
	for attempt := 1; attempt <= totalAttempts; attempt++ {
		if attempt > 1 {
			backoffIdx := attempt - 2
			if backoffIdx < len(backoff) {
				if err := opsAlertSleep(ctx, backoff[backoffIdx]); err != nil {
					return err
				}
			}
		}

		if err := ctx.Err(); err != nil {
			return err
		}

		if err := fn(); err != nil {
			lastErr = err
			nextDelay := time.Duration(0)
			if attempt < totalAttempts {
				nextIdx := attempt - 1
				if nextIdx < len(backoff) {
					nextDelay = backoff[nextIdx]
				}
			}
			if onError != nil {
				onError(attempt, totalAttempts, nextDelay, err)
			}
			continue
		}
		return nil
	}

	return lastErr
}

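To see how the attempt count relates to maxRetries, a simplified retry loop can be run against a function that fails a fixed number of times. This is a sketch under assumptions, not the helper from the diff: `retry` and `attemptsUntilSuccess` are hypothetical names, the error callback is omitted, and the delays are shortened to milliseconds so the example finishes quickly.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retry calls fn up to maxRetries+1 times, waiting backoff[i] before retry i+1.
// It stops early when ctx is done. Simplified sketch: no error callback.
func retry(ctx context.Context, maxRetries int, backoff []time.Duration, fn func() error) error {
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 && attempt-1 < len(backoff) {
			timer := time.NewTimer(backoff[attempt-1])
			select {
			case <-ctx.Done():
				timer.Stop()
				return ctx.Err()
			case <-timer.C:
			}
		}
		if err := ctx.Err(); err != nil {
			return err
		}
		if lastErr = fn(); lastErr == nil {
			return nil
		}
	}
	return lastErr
}

// attemptsUntilSuccess runs retry against a function that fails failCount
// times before succeeding. It reports how many calls were made and whether
// the final outcome was a success.
func attemptsUntilSuccess(failCount, maxRetries int) (calls int, ok bool) {
	backoff := []time.Duration{time.Millisecond, 2 * time.Millisecond, 4 * time.Millisecond}
	err := retry(context.Background(), maxRetries, backoff, func() error {
		calls++
		if calls <= failCount {
			return errors.New("transient failure")
		}
		return nil
	})
	return calls, err == nil
}

func main() {
	calls, ok := attemptsUntilSuccess(2, 3)
	fmt.Println(calls, ok) // 3 true

	calls, ok = attemptsUntilSuccess(10, 3)
	fmt.Println(calls, ok) // 4 false
}
```

With maxRetries set to 3 the function runs at most four times: one initial attempt plus three retries.
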
func (s *OpsAlertService) sendEmailNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
	if s.emailService == nil || s.userService == nil {
		return false
	}

	if ctx == nil {
		ctx = context.Background()
	}

	admin, err := s.userService.GetFirstAdmin(ctx)
	if err != nil || admin == nil || admin.Email == "" {
		return false
	}

	subject := fmt.Sprintf("[Ops Alert][%s] %s", rule.Severity, rule.Name)
	body := fmt.Sprintf(
		"Alert triggered: %s\n\nMetric: %s\nThreshold: %.2f\nCurrent: %.2f\nWindow: %dm\nStatus: %s\nTime: %s",
		rule.Name,
		rule.MetricType,
		rule.Threshold,
		event.MetricValue,
		rule.WindowMinutes,
		event.Status,
		event.FiredAt.Format(time.RFC3339),
	)

	config, err := s.emailService.GetSMTPConfig(ctx)
	if err != nil {
		log.Printf("[OpsAlert] email config load failed: %v", err)
		return false
	}

	if err := retryWithBackoff(
		ctx,
		opsAlertEmailMaxRetries,
		opsAlertEmailBackoff,
		func() error {
			return s.emailService.SendEmailWithConfig(config, admin.Email, subject, body)
		},
		func(attempt int, total int, nextDelay time.Duration, err error) {
			if attempt < total {
				log.Printf("[OpsAlert] email send failed (attempt=%d/%d), retrying in %s: %v", attempt, total, nextDelay, err)
				return
			}
			log.Printf("[OpsAlert] email send failed (attempt=%d/%d), giving up: %v", attempt, total, err)
		},
	); err != nil {
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			log.Printf("[OpsAlert] email send canceled: %v", err)
		}
		return false
	}
	return true
}

func (s *OpsAlertService) sendWebhookNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	webhookTarget, err := validateWebhookURL(ctx, rule.WebhookURL)
	if err != nil {
		log.Printf("[OpsAlert] invalid webhook url (rule=%d): %v", rule.ID, err)
		return false
	}

	payload := map[string]any{
		"rule_id":         rule.ID,
		"rule_name":       rule.Name,
		"severity":        rule.Severity,
		"status":          event.Status,
		"metric_type":     rule.MetricType,
		"metric_value":    event.MetricValue,
		"threshold_value": rule.Threshold,
		"window_minutes":  rule.WindowMinutes,
		"fired_at":        event.FiredAt.Format(time.RFC3339),
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return false
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookTarget.URL.String(), bytes.NewReader(body))
	if err != nil {
		return false
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := buildWebhookHTTPClient(s.httpClient, webhookTarget).Do(req)
	if err != nil {
		log.Printf("[OpsAlert] webhook send failed: %v", err)
		return false
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		log.Printf("[OpsAlert] webhook returned status %d", resp.StatusCode)
		return false
	}
	return true
}

const webhookHTTPClientTimeout = 10 * time.Second

func buildWebhookHTTPClient(base *http.Client, webhookTarget *validatedWebhookTarget) *http.Client {
	var client http.Client
	if base != nil {
		client = *base
	}
	if client.Timeout <= 0 {
		client.Timeout = webhookHTTPClientTimeout
	}
	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	}
	if webhookTarget != nil {
		client.Transport = buildWebhookTransport(client.Transport, webhookTarget)
	}
	return &client
}

var disallowedWebhookIPNets = []net.IPNet{
	// "this host on this network" / unspecified.
	mustParseCIDR("0.0.0.0/8"),
	mustParseCIDR("127.0.0.0/8"),    // loopback (includes 127.0.0.1)
	mustParseCIDR("10.0.0.0/8"),     // RFC1918
	mustParseCIDR("192.168.0.0/16"), // RFC1918
	mustParseCIDR("172.16.0.0/12"),  // RFC1918 (172.16.0.0 - 172.31.255.255)
	mustParseCIDR("100.64.0.0/10"),  // RFC6598 (carrier-grade NAT)
	mustParseCIDR("169.254.0.0/16"), // IPv4 link-local (includes 169.254.169.254 metadata IP on many clouds)
	mustParseCIDR("198.18.0.0/15"),  // RFC2544 benchmark testing
	mustParseCIDR("224.0.0.0/4"),    // IPv4 multicast
	mustParseCIDR("240.0.0.0/4"),    // IPv4 reserved
	mustParseCIDR("::/128"),         // IPv6 unspecified
	mustParseCIDR("::1/128"),        // IPv6 loopback
	mustParseCIDR("fc00::/7"),       // IPv6 unique local
	mustParseCIDR("fe80::/10"),      // IPv6 link-local
	mustParseCIDR("ff00::/8"),       // IPv6 multicast
}

func mustParseCIDR(cidr string) net.IPNet {
	_, block, err := net.ParseCIDR(cidr)
	if err != nil {
		panic(err)
	}
	return *block
}

var lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
	return net.DefaultResolver.LookupIPAddr(ctx, host)
}

type validatedWebhookTarget struct {
	URL *url.URL

	host      string
	port      string
	pinnedIPs []net.IP
}

var webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
	dialer := net.Dialer{
		Timeout:   5 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return dialer.DialContext(ctx, network, addr)
}

// buildWebhookTransport returns a transport that bypasses proxies and dials only
// the IPs pinned during validation. Any dial whose host or port differs from the
// validated target is rejected, which closes the DNS rebinding window between
// validation and connection.
func buildWebhookTransport(base http.RoundTripper, webhookTarget *validatedWebhookTarget) http.RoundTripper {
	if webhookTarget == nil || webhookTarget.URL == nil {
		return base
	}

	var transport *http.Transport
	switch typed := base.(type) {
	case *http.Transport:
		if typed != nil {
			transport = typed.Clone()
		}
	}
	if transport == nil {
		if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok && defaultTransport != nil {
			transport = defaultTransport.Clone()
		} else {
			transport = (&http.Transport{}).Clone()
		}
	}

	webhookHost := webhookTarget.host
	webhookPort := webhookTarget.port
	pinnedIPs := append([]net.IP(nil), webhookTarget.pinnedIPs...)

	transport.Proxy = nil
	transport.DialTLSContext = nil
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		host, port, err := net.SplitHostPort(addr)
		if err != nil || host == "" || port == "" {
			return nil, fmt.Errorf("webhook dial target is invalid: %q", addr)
		}

		canonicalHost := strings.TrimSuffix(strings.ToLower(host), ".")
		if canonicalHost != webhookHost || port != webhookPort {
			return nil, fmt.Errorf("webhook dial target mismatch: %q", addr)
		}

		var lastErr error
		for _, ip := range pinnedIPs {
			if isDisallowedWebhookIP(ip) {
				lastErr = fmt.Errorf("webhook target resolves to a disallowed ip")
				continue
			}

			dialAddr := net.JoinHostPort(ip.String(), port)
			conn, err := webhookBaseDialContext(ctx, network, dialAddr)
			if err == nil {
				return conn, nil
			}
			lastErr = err
		}
		if lastErr == nil {
			lastErr = errors.New("webhook target has no resolved addresses")
		}
		return nil, lastErr
	}

	return transport
}

// validateWebhookURL checks that raw is an https URL without userinfo, rejects
// localhost and non-public IPs, and resolves the host once so the resulting
// addresses can be pinned for the later dial.
func validateWebhookURL(ctx context.Context, raw string) (*validatedWebhookTarget, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil, errors.New("webhook url is empty")
	}
	// Avoid request smuggling / header injection vectors.
	if strings.ContainsAny(raw, "\r\n") {
		return nil, errors.New("webhook url contains invalid characters")
	}

	parsed, err := url.Parse(raw)
	if err != nil {
		return nil, errors.New("webhook url format is invalid")
	}
	if !strings.EqualFold(parsed.Scheme, "https") {
		return nil, errors.New("webhook url scheme must be https")
	}
	parsed.Scheme = "https"
	if parsed.Host == "" || parsed.Hostname() == "" {
		return nil, errors.New("webhook url must include host")
	}
	if parsed.User != nil {
		return nil, errors.New("webhook url must not include userinfo")
	}
	if parsed.Port() != "" {
		port, err := strconv.Atoi(parsed.Port())
		if err != nil || port < 1 || port > 65535 {
			return nil, errors.New("webhook url port is invalid")
		}
	}

	host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
	if host == "localhost" {
		return nil, errors.New("webhook url host must not be localhost")
	}

	if ip := net.ParseIP(host); ip != nil {
		if isDisallowedWebhookIP(ip) {
			return nil, errors.New("webhook url host resolves to a disallowed ip")
		}
		return &validatedWebhookTarget{
			URL:       parsed,
			host:      host,
			port:      portForScheme(parsed),
			pinnedIPs: []net.IP{ip},
		}, nil
	}

	if ctx == nil {
		ctx = context.Background()
	}
	ips, err := lookupIPAddrs(ctx, host)
	if err != nil || len(ips) == 0 {
		return nil, errors.New("webhook url host cannot be resolved")
	}
	pinned := make([]net.IP, 0, len(ips))
	for _, addr := range ips {
		if isDisallowedWebhookIP(addr.IP) {
			return nil, errors.New("webhook url host resolves to a disallowed ip")
		}
		if addr.IP != nil {
			pinned = append(pinned, addr.IP)
		}
	}

	if len(pinned) == 0 {
		return nil, errors.New("webhook url host cannot be resolved")
	}

	return &validatedWebhookTarget{
		URL:       parsed,
		host:      host,
		port:      portForScheme(parsed),
		pinnedIPs: uniqueResolvedIPs(pinned),
	}, nil
}

func isDisallowedWebhookIP(ip net.IP) bool {
	if ip == nil {
		return false
	}
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	} else if ip16 := ip.To16(); ip16 != nil {
		ip = ip16
	} else {
		return false
	}

	// Disallow non-public addresses even if they're not explicitly covered by the CIDR list.
	// This provides defense-in-depth against SSRF targets such as link-local, multicast, and
	// unspecified addresses, and ensures any "pinned" IP is still blocked at dial time.
	if ip.IsUnspecified() ||
		ip.IsLoopback() ||
		ip.IsMulticast() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() ||
		ip.IsPrivate() {
		return true
	}

	for _, block := range disallowedWebhookIPNets {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}

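The CIDR half of the check above can be exercised without any network access. The sketch below uses a shortened blocklist and takes string input for convenience; `blockedNets`, `mustParseAll` and `isBlocked` are hypothetical names. Unlike the function in the diff, it treats unparsable input as blocked (fail closed), which is a choice made for this sketch only.

```go
package main

import (
	"fmt"
	"net"
)

// blockedNets is a shortened, illustrative list; the diff carries the full one.
var blockedNets = mustParseAll("127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12",
	"192.168.0.0/16", "169.254.0.0/16", "::1/128", "fc00::/7", "fe80::/10")

func mustParseAll(cidrs ...string) []*net.IPNet {
	out := make([]*net.IPNet, 0, len(cidrs))
	for _, c := range cidrs {
		_, block, err := net.ParseCIDR(c)
		if err != nil {
			panic(err)
		}
		out = append(out, block)
	}
	return out
}

// isBlocked reports whether addr is an IP literal inside a blocked range.
// Unparsable input is treated as blocked (an assumption of this sketch).
func isBlocked(addr string) bool {
	ip := net.ParseIP(addr)
	if ip == nil {
		return true
	}
	if v4 := ip.To4(); v4 != nil {
		ip = v4 // normalizes IPv4-mapped IPv6 such as ::ffff:10.0.0.1
	}
	for _, block := range blockedNets {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}

func main() {
	for _, addr := range []string{"169.254.169.254", "::ffff:10.0.0.1", "8.8.8.8", "not-an-ip"} {
		fmt.Printf("%-16s blocked=%v\n", addr, isBlocked(addr))
	}
}
```

Normalizing with `To4` before matching matters: without it, an IPv4-mapped IPv6 literal could slip past a list that only names the IPv4 range.
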
func portForScheme(u *url.URL) string {
	if u != nil && u.Port() != "" {
		return u.Port()
	}
	return "443"
}

func uniqueResolvedIPs(ips []net.IP) []net.IP {
	seen := make(map[string]struct{}, len(ips))
	out := make([]net.IP, 0, len(ips))
	for _, ip := range ips {
		if ip == nil {
			continue
		}
		key := ip.String()
		if _, ok := seen[key]; ok {
			continue
		}
		seen[key] = struct{}{}
		out = append(out, ip)
	}
	return out
}
271
backend/internal/service/ops_alert_service_integration_test.go
Normal file
@@ -0,0 +1,271 @@
//go:build integration

package service

import (
	"context"
	"database/sql"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// This integration test protects the DI startup contract for OpsAlertService.
|
||||
//
|
||||
// Background:
|
||||
// - OpsMetricsCollector previously called alertService.Start()/Evaluate() directly.
// - Those direct calls were removed, so OpsAlertService must now start via DI
//   (ProvideOpsAlertService in wire.go) and run its own evaluation ticker.
//
// What we validate here:
//  1. When we construct via the Wire provider functions (ProvideOpsAlertService +
//     ProvideOpsMetricsCollector), OpsAlertService starts automatically.
//  2. Its evaluation loop continues to tick even if OpsMetricsCollector is stopped,
//     proving the alert evaluator is independent.
//  3. The evaluation path can trigger alert logic (CreateAlertEvent called).
func TestOpsAlertService_StartedViaWireProviders_RunsIndependentTicker(t *testing.T) {
	oldInterval := opsAlertEvalInterval
	opsAlertEvalInterval = 25 * time.Millisecond
	t.Cleanup(func() { opsAlertEvalInterval = oldInterval })

	repo := newFakeOpsRepository()
	opsService := NewOpsService(repo, nil)

	// Start via the Wire provider function (the production DI path).
	alertService := ProvideOpsAlertService(opsService, nil, nil)
	t.Cleanup(alertService.Stop)

	// Construct via ProvideOpsMetricsCollector (wire.go). Stop immediately to ensure
	// the alert ticker keeps running without the metrics collector.
	collector := ProvideOpsMetricsCollector(opsService, NewConcurrencyService(nil))
	collector.Stop()

	// Wait for at least one evaluation (run() calls evaluateOnce immediately).
	require.Eventually(t, func() bool {
		return repo.listRulesCalls.Load() >= 1
	}, 1*time.Second, 5*time.Millisecond)

	// Confirm the evaluation loop keeps ticking after the metrics collector is stopped.
	callsAfterCollectorStop := repo.listRulesCalls.Load()
	require.Eventually(t, func() bool {
		return repo.listRulesCalls.Load() >= callsAfterCollectorStop+2
	}, 1*time.Second, 5*time.Millisecond)

	// Confirm the evaluation logic actually fires an alert event at least once.
	select {
	case <-repo.eventCreatedCh:
		// ok
	case <-time.After(2 * time.Second):
		t.Fatalf("expected OpsAlertService to create an alert event, but none was created (ListAlertRules calls=%d)", repo.listRulesCalls.Load())
	}
}

func newFakeOpsRepository() *fakeOpsRepository {
	return &fakeOpsRepository{
		eventCreatedCh: make(chan struct{}),
	}
}

// fakeOpsRepository is a lightweight in-memory stub of OpsRepository for integration tests.
// It avoids real DB/Redis usage and provides deterministic responses fast.
type fakeOpsRepository struct {
	listRulesCalls atomic.Int64

	mu sync.Mutex
	activeEvent *OpsAlertEvent
	latestEvent *OpsAlertEvent
	nextEventID int64
	eventCreatedCh chan struct{}
	eventOnce sync.Once
}

func (r *fakeOpsRepository) CreateErrorLog(ctx context.Context, log *OpsErrorLog) error {
	return nil
}

func (r *fakeOpsRepository) ListErrorLogsLegacy(ctx context.Context, filters OpsErrorLogFilters) ([]OpsErrorLog, error) {
	return nil, nil
}

func (r *fakeOpsRepository) ListErrorLogs(ctx context.Context, filter *ErrorLogFilter) ([]*ErrorLog, int64, error) {
	return nil, 0, nil
}

func (r *fakeOpsRepository) GetLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
	return &OpsMetrics{WindowMinutes: 1}, sql.ErrNoRows
}

func (r *fakeOpsRepository) CreateSystemMetric(ctx context.Context, metric *OpsMetrics) error {
	return nil
}

func (r *fakeOpsRepository) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
	return &OpsWindowStats{}, nil
}

func (r *fakeOpsRepository) GetProviderStats(ctx context.Context, startTime, endTime time.Time) ([]*ProviderStats, error) {
	return nil, nil
}

func (r *fakeOpsRepository) GetLatencyHistogram(ctx context.Context, startTime, endTime time.Time) ([]*LatencyHistogramItem, error) {
	return nil, nil
}

func (r *fakeOpsRepository) GetErrorDistribution(ctx context.Context, startTime, endTime time.Time) ([]*ErrorDistributionItem, error) {
	return nil, nil
}

func (r *fakeOpsRepository) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
	if limit <= 0 {
		limit = 1
	}
	now := time.Now()
	metrics := make([]OpsMetrics, 0, limit)
	for i := 0; i < limit; i++ {
		metrics = append(metrics, OpsMetrics{
			WindowMinutes: windowMinutes,
			CPUUsagePercent: 99,
			UpdatedAt: now.Add(-time.Duration(i) * opsMetricsInterval),
		})
	}
	return metrics, nil
}

func (r *fakeOpsRepository) ListSystemMetricsRange(ctx context.Context, windowMinutes int, startTime, endTime time.Time, limit int) ([]OpsMetrics, error) {
	return nil, nil
}

func (r *fakeOpsRepository) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
	call := r.listRulesCalls.Add(1)
	// Delay enabling rules slightly so the test can stop OpsMetricsCollector first,
	// then observe the alert evaluator ticking independently.
	if call < 5 {
		return nil, nil
	}
	return []OpsAlertRule{
		{
			ID: 1,
			Name: "cpu too high (test)",
			Enabled: true,
			MetricType: OpsMetricCPUUsagePercent,
			Operator: ">",
			Threshold: 0,
			WindowMinutes: 1,
			SustainedMinutes: 1,
			Severity: "P1",
			NotifyEmail: false,
			NotifyWebhook: false,
			CooldownMinutes: 0,
		},
	}, nil
}

func (r *fakeOpsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.activeEvent == nil {
		return nil, nil
	}
	if r.activeEvent.RuleID != ruleID {
		return nil, nil
	}
	if r.activeEvent.Status != OpsAlertStatusFiring {
		return nil, nil
	}
	clone := *r.activeEvent
	return &clone, nil
}

func (r *fakeOpsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.latestEvent == nil || r.latestEvent.RuleID != ruleID {
		return nil, nil
	}
	clone := *r.latestEvent
	return &clone, nil
}

func (r *fakeOpsRepository) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
	if event == nil {
		return nil
	}
	r.mu.Lock()
	defer r.mu.Unlock()

	r.nextEventID++
	event.ID = r.nextEventID

	clone := *event
	r.latestEvent = &clone
	if clone.Status == OpsAlertStatusFiring {
		r.activeEvent = &clone
	}

	r.eventOnce.Do(func() { close(r.eventCreatedCh) })
	return nil
}

func (r *fakeOpsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.activeEvent != nil && r.activeEvent.ID == eventID {
		r.activeEvent.Status = status
		r.activeEvent.ResolvedAt = resolvedAt
	}
	if r.latestEvent != nil && r.latestEvent.ID == eventID {
		r.latestEvent.Status = status
		r.latestEvent.ResolvedAt = resolvedAt
	}
	return nil
}

func (r *fakeOpsRepository) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.activeEvent != nil && r.activeEvent.ID == eventID {
		r.activeEvent.EmailSent = emailSent
		r.activeEvent.WebhookSent = webhookSent
	}
	if r.latestEvent != nil && r.latestEvent.ID == eventID {
		r.latestEvent.EmailSent = emailSent
		r.latestEvent.WebhookSent = webhookSent
	}
	return nil
}

func (r *fakeOpsRepository) CountActiveAlerts(ctx context.Context) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.activeEvent == nil {
		return 0, nil
	}
	return 1, nil
}

func (r *fakeOpsRepository) GetOverviewStats(ctx context.Context, startTime, endTime time.Time) (*OverviewStats, error) {
	return &OverviewStats{}, nil
}

func (r *fakeOpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
	return nil, nil
}

func (r *fakeOpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *OpsMetrics) error {
	return nil
}

func (r *fakeOpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*DashboardOverviewData, error) {
	return nil, nil
}

func (r *fakeOpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *DashboardOverviewData, ttl time.Duration) error {
	return nil
}

func (r *fakeOpsRepository) PingRedis(ctx context.Context) error {
	return nil
}
315
backend/internal/service/ops_alert_service_test.go
Normal file
@@ -0,0 +1,315 @@
//go:build unit || opsalert_unit

package service

import (
	"context"
	"errors"
	"net"
	"net/http"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestSelectContiguousMetrics_Contiguous(t *testing.T) {
	now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	metrics := []OpsMetrics{
		{UpdatedAt: now},
		{UpdatedAt: now.Add(-1 * time.Minute)},
		{UpdatedAt: now.Add(-2 * time.Minute)},
	}

	selected, ok := selectContiguousMetrics(metrics, 3, now)
	require.True(t, ok)
	require.Len(t, selected, 3)
}

func TestSelectContiguousMetrics_GapFails(t *testing.T) {
	now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	metrics := []OpsMetrics{
		{UpdatedAt: now},
		// Missing the -1m sample (gap ~=2m).
		{UpdatedAt: now.Add(-2 * time.Minute)},
		{UpdatedAt: now.Add(-3 * time.Minute)},
	}

	_, ok := selectContiguousMetrics(metrics, 3, now)
	require.False(t, ok)
}

func TestSelectContiguousMetrics_StaleNewestFails(t *testing.T) {
	now := time.Date(2026, 1, 1, 0, 10, 0, 0, time.UTC)
	metrics := []OpsMetrics{
		{UpdatedAt: now.Add(-10 * time.Minute)},
		{UpdatedAt: now.Add(-11 * time.Minute)},
	}

	_, ok := selectContiguousMetrics(metrics, 2, now)
	require.False(t, ok)
}

func TestMetricValue_SuccessRate_NoTrafficIsNoData(t *testing.T) {
	metric := OpsMetrics{
		RequestCount: 0,
		SuccessRate: 0,
	}
	value, ok := metricValue(metric, OpsMetricSuccessRate)
	require.False(t, ok)
	require.Equal(t, 0.0, value)
}

func TestOpsAlertService_StopWithoutStart_NoPanic(t *testing.T) {
	s := NewOpsAlertService(nil, nil, nil)
	require.NotPanics(t, func() { s.Stop() })
}

func TestOpsAlertService_StartStop_Graceful(t *testing.T) {
	s := NewOpsAlertService(nil, nil, nil)
	s.interval = 5 * time.Millisecond

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s.StartWithContext(ctx)

	done := make(chan struct{})
	go func() {
		s.Stop()
		close(done)
	}()

	select {
	case <-done:
		// ok
	case <-time.After(1 * time.Second):
		t.Fatal("Stop did not return; background goroutine likely stuck")
	}

	require.NotPanics(t, func() { s.Stop() })
}

func TestBuildWebhookHTTPClient_DefaultTimeout(t *testing.T) {
	client := buildWebhookHTTPClient(nil, nil)
	require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
	require.NotNil(t, client.CheckRedirect)
	require.ErrorIs(t, client.CheckRedirect(nil, nil), http.ErrUseLastResponse)

	base := &http.Client{}
	client = buildWebhookHTTPClient(base, nil)
	require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
	require.NotNil(t, client.CheckRedirect)

	base = &http.Client{Timeout: 2 * time.Second}
	client = buildWebhookHTTPClient(base, nil)
	require.Equal(t, 2*time.Second, client.Timeout)
	require.NotNil(t, client.CheckRedirect)
}

func TestValidateWebhookURL_RequiresHTTPS(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
	}

	_, err := validateWebhookURL(context.Background(), "http://example.com/webhook")
	require.Error(t, err)
}

func TestValidateWebhookURL_InvalidFormatRejected(t *testing.T) {
	_, err := validateWebhookURL(context.Background(), "https://[::1")
	require.Error(t, err)
}

func TestValidateWebhookURL_RejectsUserinfo(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
	}

	_, err := validateWebhookURL(context.Background(), "https://user:pass@example.com/webhook")
	require.Error(t, err)
}

func TestValidateWebhookURL_RejectsLocalhost(t *testing.T) {
	_, err := validateWebhookURL(context.Background(), "https://localhost/webhook")
	require.Error(t, err)
}

func TestValidateWebhookURL_RejectsPrivateIPLiteral(t *testing.T) {
	cases := []string{
		"https://0.0.0.0/webhook",
		"https://127.0.0.1/webhook",
		"https://10.0.0.1/webhook",
		"https://192.168.1.2/webhook",
		"https://172.16.0.1/webhook",
		"https://172.31.255.255/webhook",
		"https://100.64.0.1/webhook",
		"https://169.254.169.254/webhook",
		"https://198.18.0.1/webhook",
		"https://224.0.0.1/webhook",
		"https://240.0.0.1/webhook",
		"https://[::]/webhook",
		"https://[::1]/webhook",
		"https://[ff02::1]/webhook",
	}
	for _, tc := range cases {
		t.Run(tc, func(t *testing.T) {
			_, err := validateWebhookURL(context.Background(), tc)
			require.Error(t, err)
		})
	}
}

func TestValidateWebhookURL_RejectsPrivateIPViaDNS(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		require.Equal(t, "internal.example", host)
		return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
	}

	_, err := validateWebhookURL(context.Background(), "https://internal.example/webhook")
	require.Error(t, err)
}

func TestValidateWebhookURL_RejectsLinkLocalIPViaDNS(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		require.Equal(t, "metadata.example", host)
		return []net.IPAddr{{IP: net.ParseIP("169.254.169.254")}}, nil
	}

	_, err := validateWebhookURL(context.Background(), "https://metadata.example/webhook")
	require.Error(t, err)
}

func TestValidateWebhookURL_AllowsPublicHostViaDNS(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		require.Equal(t, "example.com", host)
		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
	}

	target, err := validateWebhookURL(context.Background(), "https://example.com:443/webhook")
	require.NoError(t, err)
	require.Equal(t, "https", target.URL.Scheme)
	require.Equal(t, "example.com", target.URL.Hostname())
	require.Equal(t, "443", target.URL.Port())
}

func TestValidateWebhookURL_RejectsInvalidPort(t *testing.T) {
	oldLookup := lookupIPAddrs
	t.Cleanup(func() { lookupIPAddrs = oldLookup })
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
	}

	_, err := validateWebhookURL(context.Background(), "https://example.com:99999/webhook")
	require.Error(t, err)
}

func TestWebhookTransport_UsesPinnedIP_NoDNSRebinding(t *testing.T) {
	oldLookup := lookupIPAddrs
	oldDial := webhookBaseDialContext
	t.Cleanup(func() {
		lookupIPAddrs = oldLookup
		webhookBaseDialContext = oldDial
	})

	lookupCalls := 0
	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		lookupCalls++
		require.Equal(t, "example.com", host)
		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
	}

	target, err := validateWebhookURL(context.Background(), "https://example.com/webhook")
	require.NoError(t, err)
	require.Equal(t, 1, lookupCalls)

	lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
		lookupCalls++
		return []net.IPAddr{{IP: net.ParseIP("10.0.0.1")}}, nil
	}

	var dialAddrs []string
	webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		dialAddrs = append(dialAddrs, addr)
		return nil, errors.New("dial blocked in test")
	}

	client := buildWebhookHTTPClient(nil, target)
	transport, ok := client.Transport.(*http.Transport)
	require.True(t, ok)

	_, err = transport.DialContext(context.Background(), "tcp", "example.com:443")
	require.Error(t, err)
	require.Equal(t, []string{"93.184.216.34:443"}, dialAddrs)
	require.Equal(t, 1, lookupCalls, "dial path must not re-resolve DNS")
}

func TestRetryWithBackoff_SucceedsAfterRetries(t *testing.T) {
	oldSleep := opsAlertSleep
	t.Cleanup(func() { opsAlertSleep = oldSleep })

	var slept []time.Duration
	opsAlertSleep = func(ctx context.Context, d time.Duration) error {
		slept = append(slept, d)
		return nil
	}

	attempts := 0
	err := retryWithBackoff(
		context.Background(),
		3,
		[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
		func() error {
			attempts++
			if attempts <= 3 {
				return errors.New("send failed")
			}
			return nil
		},
		nil,
	)
	require.NoError(t, err)
	require.Equal(t, 4, attempts)
	require.Equal(t, []time.Duration{time.Second, 2 * time.Second, 4 * time.Second}, slept)
}

func TestRetryWithBackoff_ContextCanceledStopsRetries(t *testing.T) {
	oldSleep := opsAlertSleep
	t.Cleanup(func() { opsAlertSleep = oldSleep })

	var slept []time.Duration
	opsAlertSleep = func(ctx context.Context, d time.Duration) error {
		slept = append(slept, d)
		return ctx.Err()
	}

	ctx, cancel := context.WithCancel(context.Background())
	attempts := 0
	err := retryWithBackoff(
		ctx,
		3,
		[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
		func() error {
			attempts++
			return errors.New("send failed")
		},
		func(attempt int, total int, nextDelay time.Duration, err error) {
			if attempt == 1 {
				cancel()
			}
		},
	)
	require.ErrorIs(t, err, context.Canceled)
	require.Equal(t, 1, attempts)
	require.Equal(t, []time.Duration{time.Second}, slept)
}
92
backend/internal/service/ops_alerts.go
Normal file
@@ -0,0 +1,92 @@
package service

import (
	"context"
	"time"
)

const (
	OpsAlertStatusFiring = "firing"
	OpsAlertStatusResolved = "resolved"
)

const (
	OpsMetricSuccessRate = "success_rate"
	OpsMetricErrorRate = "error_rate"
	OpsMetricP95LatencyMs = "p95_latency_ms"
	OpsMetricP99LatencyMs = "p99_latency_ms"
	OpsMetricHTTP2Errors = "http2_errors"
	OpsMetricCPUUsagePercent = "cpu_usage_percent"
	OpsMetricMemoryUsagePercent = "memory_usage_percent"
	OpsMetricQueueDepth = "concurrency_queue_depth"
)

type OpsAlertRule struct {
	ID int64 `json:"id"`
	Name string `json:"name"`
	Description string `json:"description"`
	Enabled bool `json:"enabled"`
	MetricType string `json:"metric_type"`
	Operator string `json:"operator"`
	Threshold float64 `json:"threshold"`
	WindowMinutes int `json:"window_minutes"`
	SustainedMinutes int `json:"sustained_minutes"`
	Severity string `json:"severity"`
	NotifyEmail bool `json:"notify_email"`
	NotifyWebhook bool `json:"notify_webhook"`
	WebhookURL string `json:"webhook_url"`
	CooldownMinutes int `json:"cooldown_minutes"`
	DimensionFilters map[string]any `json:"dimension_filters,omitempty"`
	NotifyChannels []string `json:"notify_channels,omitempty"`
	NotifyConfig map[string]any `json:"notify_config,omitempty"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}

type OpsAlertEvent struct {
	ID int64 `json:"id"`
	RuleID int64 `json:"rule_id"`
	Severity string `json:"severity"`
	Status string `json:"status"`
	Title string `json:"title"`
	Description string `json:"description"`
	MetricValue float64 `json:"metric_value"`
	ThresholdValue float64 `json:"threshold_value"`
	FiredAt time.Time `json:"fired_at"`
	ResolvedAt *time.Time `json:"resolved_at"`
	EmailSent bool `json:"email_sent"`
	WebhookSent bool `json:"webhook_sent"`
	CreatedAt time.Time `json:"created_at"`
}

func (s *OpsService) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
	return s.repo.ListAlertRules(ctx)
}

func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	return s.repo.GetActiveAlertEvent(ctx, ruleID)
}

func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	return s.repo.GetLatestAlertEvent(ctx, ruleID)
}

func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
	return s.repo.CreateAlertEvent(ctx, event)
}

func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
	return s.repo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
}

func (s *OpsService) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
	return s.repo.UpdateAlertEventNotifications(ctx, eventID, emailSent, webhookSent)
}

func (s *OpsService) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
	return s.repo.ListRecentSystemMetrics(ctx, windowMinutes, limit)
}

func (s *OpsService) CountActiveAlerts(ctx context.Context) (int, error) {
	return s.repo.CountActiveAlerts(ctx)
}
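`OpsAlertRule` carries an `Operator`, a `Threshold`, and a `SustainedMinutes` field, but the evaluation code that consumes them is in ops_alert_service.go and is not shown here. As a rough illustration only, a rule of this shape might be evaluated as below. The functions `compare` and `sustainedBreach` are hypothetical, and the set of operators beyond `">"` (the only one appearing in the tests) is an assumption.

```go
package main

import "fmt"

// compare applies a rule operator to a metric value and a threshold.
// Only ">" is confirmed by the tests; the other operators are assumed.
func compare(op string, value, threshold float64) (bool, error) {
	switch op {
	case ">":
		return value > threshold, nil
	case ">=":
		return value >= threshold, nil
	case "<":
		return value < threshold, nil
	case "<=":
		return value <= threshold, nil
	case "==":
		return value == threshold, nil
	default:
		return false, fmt.Errorf("unsupported operator %q", op)
	}
}

// sustainedBreach reports whether every sample breaches the threshold, which is
// one plausible reading of a "sustained for N minutes" rule with one sample per
// minute. An empty sample set counts as no breach.
func sustainedBreach(op string, threshold float64, samples []float64) (bool, error) {
	if len(samples) == 0 {
		return false, nil
	}
	for _, v := range samples {
		ok, err := compare(op, v, threshold)
		if err != nil {
			return false, err
		}
		if !ok {
			return false, nil
		}
	}
	return true, nil
}

func main() {
	// Mirrors the fake rule in the integration test: cpu_usage_percent > 0.
	fired, err := sustainedBreach(">", 0, []float64{99})
	fmt.Println(fired, err)
}
```

Requiring every sample in the window to breach avoids firing on a single spike, at the cost of delaying the alert by the sustained period.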
203
backend/internal/service/ops_metrics_collector.go
Normal file
@@ -0,0 +1,203 @@
package service

import (
	"context"
	"log"
	"runtime"
	"sync"
	"time"

	"github.com/shirou/gopsutil/v4/cpu"
	"github.com/shirou/gopsutil/v4/mem"
)

const (
	opsMetricsInterval = 1 * time.Minute
	opsMetricsCollectTimeout = 10 * time.Second

	opsMetricsWindowShortMinutes = 1
	opsMetricsWindowLongMinutes = 5

	bytesPerMB = 1024 * 1024
	cpuUsageSampleInterval = 0 * time.Second

	percentScale = 100
)

type OpsMetricsCollector struct {
	opsService *OpsService
	concurrencyService *ConcurrencyService
	interval time.Duration
	lastGCPauseTotal uint64
	lastGCPauseMu sync.Mutex
	stopCh chan struct{}
	startOnce sync.Once
	stopOnce sync.Once
}

func NewOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
	return &OpsMetricsCollector{
		opsService: opsService,
		concurrencyService: concurrencyService,
		interval: opsMetricsInterval,
	}
}

func (c *OpsMetricsCollector) Start() {
	if c == nil {
		return
	}
	c.startOnce.Do(func() {
		if c.stopCh == nil {
			c.stopCh = make(chan struct{})
		}
		go c.run()
	})
}

func (c *OpsMetricsCollector) Stop() {
	if c == nil {
		return
	}
	c.stopOnce.Do(func() {
		if c.stopCh != nil {
			close(c.stopCh)
		}
	})
}

func (c *OpsMetricsCollector) run() {
	ticker := time.NewTicker(c.interval)
	defer ticker.Stop()

	c.collectOnce()
	for {
		select {
		case <-ticker.C:
			c.collectOnce()
		case <-c.stopCh:
			return
		}
	}
}

func (c *OpsMetricsCollector) collectOnce() {
	if c.opsService == nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectTimeout)
	defer cancel()

	now := time.Now()
	systemStats := c.collectSystemStats(ctx)
	queueDepth := c.collectQueueDepth(ctx)
	activeAlerts := c.collectActiveAlerts(ctx)

	for _, window := range []int{opsMetricsWindowShortMinutes, opsMetricsWindowLongMinutes} {
		startTime := now.Add(-time.Duration(window) * time.Minute)
		windowStats, err := c.opsService.GetWindowStats(ctx, startTime, now)
		if err != nil {
			log.Printf("[OpsMetrics] failed to get window stats (%dm): %v", window, err)
			continue
		}

		successRate, errorRate := computeRates(windowStats.SuccessCount, windowStats.ErrorCount)
		requestCount := windowStats.SuccessCount + windowStats.ErrorCount
		metric := &OpsMetrics{
			WindowMinutes: window,
			RequestCount: requestCount,
			SuccessCount: windowStats.SuccessCount,
			ErrorCount: windowStats.ErrorCount,
			SuccessRate: successRate,
			ErrorRate: errorRate,
			P95LatencyMs: windowStats.P95LatencyMs,
			P99LatencyMs: windowStats.P99LatencyMs,
			HTTP2Errors: windowStats.HTTP2Errors,
			ActiveAlerts: activeAlerts,
			CPUUsagePercent: systemStats.cpuUsage,
			MemoryUsedMB: systemStats.memoryUsedMB,
			MemoryTotalMB: systemStats.memoryTotalMB,
			MemoryUsagePercent: systemStats.memoryUsagePercent,
			HeapAllocMB: systemStats.heapAllocMB,
			GCPauseMs: systemStats.gcPauseMs,
			ConcurrencyQueueDepth: queueDepth,
			UpdatedAt: now,
		}

		if err := c.opsService.RecordMetrics(ctx, metric); err != nil {
			log.Printf("[OpsMetrics] failed to record metrics (%dm): %v", window, err)
		}
	}

}

func computeRates(successCount, errorCount int64) (float64, float64) {
	total := successCount + errorCount
	if total == 0 {
		// No traffic => no data. Rates are kept at 0 and request_count will be 0.
		// The UI should render this as N/A instead of "100% success".
		return 0, 0
	}
	successRate := float64(successCount) / float64(total) * percentScale
	errorRate := float64(errorCount) / float64(total) * percentScale
	return successRate, errorRate
}

type opsSystemStats struct {
	cpuUsage float64
	memoryUsedMB int64
	memoryTotalMB int64
	memoryUsagePercent float64
	heapAllocMB int64
	gcPauseMs float64
}

func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) opsSystemStats {
	stats := opsSystemStats{}

	if percents, err := cpu.PercentWithContext(ctx, cpuUsageSampleInterval, false); err == nil && len(percents) > 0 {
		stats.cpuUsage = percents[0]
	}

	if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil {
		stats.memoryUsedMB = int64(vm.Used / bytesPerMB)
		stats.memoryTotalMB = int64(vm.Total / bytesPerMB)
		stats.memoryUsagePercent = vm.UsedPercent
	}

	var memStats runtime.MemStats
	runtime.ReadMemStats(&memStats)
	stats.heapAllocMB = int64(memStats.HeapAlloc / bytesPerMB)
	c.lastGCPauseMu.Lock()
	if c.lastGCPauseTotal != 0 && memStats.PauseTotalNs >= c.lastGCPauseTotal {
		stats.gcPauseMs = float64(memStats.PauseTotalNs-c.lastGCPauseTotal) / float64(time.Millisecond)
	}
	c.lastGCPauseTotal = memStats.PauseTotalNs
	c.lastGCPauseMu.Unlock()

	return stats
}

func (c *OpsMetricsCollector) collectQueueDepth(ctx context.Context) int {
	if c.concurrencyService == nil {
		return 0
	}
	depth, err := c.concurrencyService.GetTotalWaitCount(ctx)
	if err != nil {
		log.Printf("[OpsMetrics] failed to get queue depth: %v", err)
		return 0
	}
	return depth
}

func (c *OpsMetricsCollector) collectActiveAlerts(ctx context.Context) int {
	if c.opsService == nil {
		return 0
	}
	count, err := c.opsService.CountActiveAlerts(ctx)
	if err != nil {
		return 0
	}
	return count
}
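`computeRates` stores 0 for both rates when a window has no traffic, and the unit test `TestMetricValue_SuccessRate_NoTrafficIsNoData` expects consumers to treat that case as "no data" rather than as a 0% success rate. A minimal sketch of how a consumer might make that distinction explicit is shown below. The functions `successRateOrNoData` and `formatRate` are hypothetical and do not appear in the PR.

```go
package main

import "fmt"

const percentScale = 100

// successRateOrNoData returns the success rate in percent and ok=true, or
// (0, false) when there were no requests in the window.
func successRateOrNoData(successCount, errorCount int64) (rate float64, ok bool) {
	total := successCount + errorCount
	if total == 0 {
		return 0, false
	}
	return float64(successCount) / float64(total) * percentScale, true
}

// formatRate renders a rate for display, using "N/A" for the no-data case so
// an idle window is not shown as either 0% or 100% success.
func formatRate(rate float64, ok bool) string {
	if !ok {
		return "N/A"
	}
	return fmt.Sprintf("%.1f%%", rate)
}

func main() {
	fmt.Println(formatRate(successRateOrNoData(0, 0)))
	fmt.Println(formatRate(successRateOrNoData(3, 1)))
}
```

Returning an explicit `ok` flag keeps alert rules such as "success_rate < 95" from firing on an idle system, which a bare 0 would trigger.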
1020
backend/internal/service/ops_service.go
Normal file
File diff suppressed because it is too large
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
 		SettingKeySiteName,
 		SettingKeySiteLogo,
 		SettingKeySiteSubtitle,
-		SettingKeyApiBaseUrl,
+		SettingKeyAPIBaseURL,
 		SettingKeyContactInfo,
-		SettingKeyDocUrl,
+		SettingKeyDocURL,
 	}

 	settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
 		SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
 		SiteLogo: settings[SettingKeySiteLogo],
 		SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
-		ApiBaseUrl: settings[SettingKeyApiBaseUrl],
+		APIBaseURL: settings[SettingKeyAPIBaseURL],
 		ContactInfo: settings[SettingKeyContactInfo],
-		DocUrl: settings[SettingKeyDocUrl],
+		DocURL: settings[SettingKeyDocURL],
 	}, nil
 }

@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
 	updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)

 	// Email service settings (the password is only updated when non-empty)
-	updates[SettingKeySmtpHost] = settings.SmtpHost
-	updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
-	updates[SettingKeySmtpUsername] = settings.SmtpUsername
-	if settings.SmtpPassword != "" {
-		updates[SettingKeySmtpPassword] = settings.SmtpPassword
+	updates[SettingKeySMTPHost] = settings.SMTPHost
+	updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
+	updates[SettingKeySMTPUsername] = settings.SMTPUsername
+	if settings.SMTPPassword != "" {
+		updates[SettingKeySMTPPassword] = settings.SMTPPassword
 	}
-	updates[SettingKeySmtpFrom] = settings.SmtpFrom
-	updates[SettingKeySmtpFromName] = settings.SmtpFromName
-	updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
+	updates[SettingKeySMTPFrom] = settings.SMTPFrom
+	updates[SettingKeySMTPFromName] = settings.SMTPFromName
+	updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)

 	// Cloudflare Turnstile settings (the secret key is only updated when non-empty)
 	updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,9 +115,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
 	updates[SettingKeySiteName] = settings.SiteName
 	updates[SettingKeySiteLogo] = settings.SiteLogo
 	updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
-	updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
+	updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
 	updates[SettingKeyContactInfo] = settings.ContactInfo
-	updates[SettingKeyDocUrl] = settings.DocUrl
+	updates[SettingKeyDocURL] = settings.DocURL

 	// Default settings
 	updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -198,8 +198,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
 		SettingKeySiteLogo: "",
 		SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
 		SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
-		SettingKeySmtpPort: "587",
-		SettingKeySmtpUseTLS: "false",
+		SettingKeySMTPPort: "587",
+		SettingKeySMTPUseTLS: "false",
 	}
 
 	return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +210,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
 	result := &SystemSettings{
 		RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
 		EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
-		SmtpHost: settings[SettingKeySmtpHost],
-		SmtpUsername: settings[SettingKeySmtpUsername],
-		SmtpFrom: settings[SettingKeySmtpFrom],
-		SmtpFromName: settings[SettingKeySmtpFromName],
-		SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
+		SMTPHost: settings[SettingKeySMTPHost],
+		SMTPUsername: settings[SettingKeySMTPUsername],
+		SMTPFrom: settings[SettingKeySMTPFrom],
+		SMTPFromName: settings[SettingKeySMTPFromName],
+		SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
 		TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
 		TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
 		SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
 		SiteLogo: settings[SettingKeySiteLogo],
 		SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
-		ApiBaseUrl: settings[SettingKeyApiBaseUrl],
+		APIBaseURL: settings[SettingKeyAPIBaseURL],
 		ContactInfo: settings[SettingKeyContactInfo],
-		DocUrl: settings[SettingKeyDocUrl],
+		DocURL: settings[SettingKeyDocURL],
 	}
 
 	// Parse integer-typed settings
-	if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
-		result.SmtpPort = port
+	if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
+		result.SMTPPort = port
 	} else {
-		result.SmtpPort = 587
+		result.SMTPPort = 587
 	}
 
 	if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -245,8 +245,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
 		result.DefaultBalance = s.cfg.Default.UserBalance
 	}
 
-	// Return sensitive values as-is so they can be used when testing the connection
-	result.SmtpPassword = settings[SettingKeySmtpPassword]
+	// Return sensitive values as-is so they can be used when testing the connection
+	result.SMTPPassword = settings[SettingKeySMTPPassword]
 	result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
 
 	return result
@@ -278,28 +278,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
 	return value
 }
 
-// GenerateAdminApiKey generates a new admin API key
-func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
+// GenerateAdminAPIKey generates a new admin API key
+func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
 	// Generate 32 random bytes = 64 hexadecimal characters
 	bytes := make([]byte, 32)
 	if _, err := rand.Read(bytes); err != nil {
 		return "", fmt.Errorf("generate random bytes: %w", err)
 	}
 
-	key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
+	key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
 
 	// Store in the settings table
-	if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
+	if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
 		return "", fmt.Errorf("save admin api key: %w", err)
 	}
 
 	return key, nil
 }
 
-// GetAdminApiKeyStatus returns the status of the admin API key
+// GetAdminAPIKeyStatus returns the status of the admin API key
 // Returns the masked key, whether it exists, and an error
-func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
-	key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
+func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
+	key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
 	if err != nil {
 		if errors.Is(err, ErrSettingNotFound) {
 			return "", false, nil
@@ -320,10 +320,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
 	return maskedKey, true, nil
 }
 
-// GetAdminApiKey returns the full admin API key (for internal validation only)
+// GetAdminAPIKey returns the full admin API key (for internal validation only)
 // Returns an empty string and a nil error if not configured; an error is returned only on database failure
-func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
-	key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
+func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
+	key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
 	if err != nil {
 		if errors.Is(err, ErrSettingNotFound) {
 			return "", nil // not configured; return an empty string
@@ -333,7 +333,7 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
 	return key, nil
 }
 
-// DeleteAdminApiKey deletes the admin API key
-func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
-	return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
+// DeleteAdminAPIKey deletes the admin API key
+func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
+	return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
 }
@@ -4,13 +4,13 @@ type SystemSettings struct {
 	RegistrationEnabled bool
 	EmailVerifyEnabled bool
 
-	SmtpHost string
-	SmtpPort int
-	SmtpUsername string
-	SmtpPassword string
-	SmtpFrom string
-	SmtpFromName string
-	SmtpUseTLS bool
+	SMTPHost string
+	SMTPPort int
+	SMTPUsername string
+	SMTPPassword string
+	SMTPFrom string
+	SMTPFromName string
+	SMTPUseTLS bool
 
 	TurnstileEnabled bool
 	TurnstileSiteKey string
@@ -19,9 +19,9 @@ type SystemSettings struct {
 	SiteName string
 	SiteLogo string
 	SiteSubtitle string
-	ApiBaseUrl string
+	APIBaseURL string
 	ContactInfo string
-	DocUrl string
+	DocURL string
 
 	DefaultConcurrency int
 	DefaultBalance float64
@@ -35,8 +35,8 @@ type PublicSettings struct {
 	SiteName string
 	SiteLogo string
 	SiteSubtitle string
-	ApiBaseUrl string
+	APIBaseURL string
 	ContactInfo string
-	DocUrl string
+	DocURL string
 	Version string
 }
@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
 		{
 			name: "anthropic api-key - cannot refresh",
 			platform: PlatformAnthropic,
-			accType: AccountTypeApiKey,
+			accType: AccountTypeAPIKey,
 			want: false,
 		},
 		{
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
 	Name string `json:"name"`
 	Body string `json:"body"`
 	PublishedAt string `json:"published_at"`
-	HtmlURL string `json:"html_url"`
+	HTMLURL string `json:"html_url"`
 	Assets []Asset `json:"assets,omitempty"`
 }
 
@@ -96,13 +96,13 @@ type GitHubRelease struct {
 	Name string `json:"name"`
 	Body string `json:"body"`
 	PublishedAt string `json:"published_at"`
-	HtmlUrl string `json:"html_url"`
+	HTMLURL string `json:"html_url"`
 	Assets []GitHubAsset `json:"assets"`
 }
 
 type GitHubAsset struct {
 	Name string `json:"name"`
-	BrowserDownloadUrl string `json:"browser_download_url"`
+	BrowserDownloadURL string `json:"browser_download_url"`
 	Size int64 `json:"size"`
 }
 
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
 	for i, a := range release.Assets {
 		assets[i] = Asset{
 			Name: a.Name,
-			DownloadURL: a.BrowserDownloadUrl,
+			DownloadURL: a.BrowserDownloadURL,
 			Size: a.Size,
 		}
 	}
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
 			Name: release.Name,
 			Body: release.Body,
 			PublishedAt: release.PublishedAt,
-			HtmlURL: release.HtmlUrl,
+			HTMLURL: release.HTMLURL,
 			Assets: assets,
 		},
 		Cached: false,
@@ -10,7 +10,7 @@ const (
 type UsageLog struct {
 	ID int64
 	UserID int64
-	ApiKeyID int64
+	APIKeyID int64
 	AccountID int64
 	RequestID string
 	Model string
@@ -42,7 +42,7 @@ type UsageLog struct {
 	CreatedAt time.Time
 
 	User *User
-	ApiKey *ApiKey
+	APIKey *APIKey
 	Account *Account
 	Group *Group
 	Subscription *UserSubscription
@@ -17,7 +17,7 @@ var (
 // CreateUsageLogRequest is the request to create a usage log
 type CreateUsageLogRequest struct {
 	UserID int64 `json:"user_id"`
-	ApiKeyID int64 `json:"api_key_id"`
+	APIKeyID int64 `json:"api_key_id"`
 	AccountID int64 `json:"account_id"`
 	RequestID string `json:"request_id"`
 	Model string `json:"model"`
@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
 	// Create the usage log
 	usageLog := &UsageLog{
 		UserID: req.UserID,
-		ApiKeyID: req.ApiKeyID,
+		APIKeyID: req.APIKeyID,
 		AccountID: req.AccountID,
 		RequestID: req.RequestID,
 		Model: req.Model,
@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
 	return logs, pagination, nil
 }
 
-// ListByApiKey returns the usage logs for an API key
-func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
-	logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
+// ListByAPIKey returns the usage logs for an API key
+func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
+	logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
 	if err != nil {
 		return nil, nil, fmt.Errorf("list usage logs: %w", err)
 	}
@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
 	}, nil
 }
 
-// GetStatsByApiKey returns usage statistics for an API key
-func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
-	stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
+// GetStatsByAPIKey returns usage statistics for an API key
+func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
+	stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
 	if err != nil {
 		return nil, fmt.Errorf("get api key stats: %w", err)
 	}
@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
 	return stats, nil
 }
 
-// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
-func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
-	stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
+// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
+func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
+	stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
 	if err != nil {
 		return nil, fmt.Errorf("get batch api key usage stats: %w", err)
 	}
@@ -21,7 +21,7 @@ type User struct {
 	CreatedAt time.Time
 	UpdatedAt time.Time
 
-	ApiKeys []ApiKey
+	APIKeys []APIKey
 	Subscriptions []UserSubscription
 }
 
@@ -73,6 +73,20 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
 	return svc
 }
 
+// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
+func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
+	svc := NewOpsMetricsCollector(opsService, concurrencyService)
+	svc.Start()
+	return svc
+}
+
+// ProvideOpsAlertService creates and starts OpsAlertService.
+func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
+	svc := NewOpsAlertService(opsService, userService, emailService)
+	svc.Start()
+	return svc
+}
+
 // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
 func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
 	svc := NewConcurrencyService(cache)
@@ -87,13 +101,14 @@ var ProviderSet = wire.NewSet(
 	// Core services
 	NewAuthService,
 	NewUserService,
-	NewApiKeyService,
+	NewAPIKeyService,
 	NewGroupService,
 	NewAccountService,
 	NewProxyService,
 	NewRedeemService,
 	NewUsageService,
 	NewDashboardService,
+	NewOpsService,
 	ProvidePricingService,
 	NewBillingService,
 	NewBillingCacheService,
@@ -125,5 +140,7 @@ var ProviderSet = wire.NewSet(
 	ProvideTimingWheelService,
 	ProvideDeferredService,
 	ProvideAntigravityQuotaRefresher,
+	ProvideOpsMetricsCollector,
+	ProvideOpsAlertService,
 	NewUserAttributeService,
 )