运维监控系统安全加固和功能优化 (#21)

* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
This commit is contained in:
IanShaw
2026-01-02 20:01:12 +08:00
committed by GitHub
parent 7fdc2b2d29
commit 45bd9ac705
171 changed files with 10618 additions and 2965 deletions

View File

@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
if a.Type != AccountTypeAPIKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
func (a *Account) IsOpenAIAPIKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
func (a *Account) GetOpenAIAPIKey() string {
if !a.IsOpenAIAPIKey() {
return ""
}
return a.GetCredential("api_key")

View File

@@ -1,3 +1,5 @@
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service
import (

View File

@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIApiKey()
authToken = account.GetOpenAIAPIKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -17,11 +17,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +32,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +51,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)

View File

@@ -19,7 +19,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
@@ -30,7 +30,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +65,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
// CreateUserInput represents the input for creating a new user
type CreateUserInput struct {
Email string
Password string
@@ -220,7 +220,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
@@ -232,7 +232,7 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {

View File

@@ -2,7 +2,7 @@ package service
import "time"
type ApiKey struct {
type APIKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group
}
func (k *ApiKey) IsActive() bool {
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cache APIKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cache APIKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
) *APIKeyService {
return &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewApiKeyService(
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
return ErrAPIKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
return ErrAPIKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
return ErrAPIKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
// Key已存在,增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &ApiKey{
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
func TestAPIKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -32,6 +32,7 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {

View File

@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
} `json:"data"`
}
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,20 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SmtpConfig SMTP配置
type SmtpConfig struct {
// SMTPConfig SMTP配置
type SMTPConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySmtpHost]
host := settings[SettingKeySMTPHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
config, err := s.GetSMTPConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

View File

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1639,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}

View File

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")

View File

@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

View File

@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
// FetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient()

View File

@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
case AccountTypeAPIKey:
apiKey := account.GetOpenAIAPIKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,

View File

@@ -0,0 +1,99 @@
package service
import (
"context"
"time"
)
// ErrorLog represents an ops error log item for list queries.
//
// Field naming matches docs/API-运维监控中心2.0.md (L3 根因追踪 - 错误日志列表).
type ErrorLog struct {
ID int64 `json:"id"`
Timestamp time.Time `json:"timestamp"`
Level string `json:"level,omitempty"`
RequestID string `json:"request_id,omitempty"`
AccountID string `json:"account_id,omitempty"`
APIPath string `json:"api_path,omitempty"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
HTTPCode int `json:"http_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
DurationMs *int `json:"duration_ms,omitempty"`
RetryCount *int `json:"retry_count,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ErrorLogFilter describes optional filters and pagination for listing ops error logs.
type ErrorLogFilter struct {
StartTime *time.Time
EndTime *time.Time
ErrorCode *int
Provider string
AccountID *int64
Page int
PageSize int
}
func (f *ErrorLogFilter) normalize() (page, pageSize int) {
page = 1
pageSize = 20
if f == nil {
return page, pageSize
}
if f.Page > 0 {
page = f.Page
}
if f.PageSize > 0 {
pageSize = f.PageSize
}
if pageSize > 100 {
pageSize = 100
}
return page, pageSize
}
type ErrorLogListResponse struct {
Errors []*ErrorLog `json:"errors"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *ErrorLogFilter) (*ErrorLogListResponse, error) {
if s == nil || s.repo == nil {
return &ErrorLogListResponse{
Errors: []*ErrorLog{},
Total: 0,
Page: 1,
PageSize: 20,
}, nil
}
page, pageSize := filter.normalize()
if filter == nil {
filter = &ErrorLogFilter{}
}
filter.Page = page
filter.PageSize = pageSize
items, total, err := s.repo.ListErrorLogs(ctx, filter)
if err != nil {
return nil, err
}
if items == nil {
items = []*ErrorLog{}
}
return &ErrorLogListResponse{
Errors: items,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}

View File

@@ -0,0 +1,834 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
type OpsAlertService struct {
opsService *OpsService
userService *UserService
emailService *EmailService
httpClient *http.Client
interval time.Duration
startOnce sync.Once
stopOnce sync.Once
stopCtx context.Context
stop context.CancelFunc
wg sync.WaitGroup
}
// opsAlertEvalInterval defines how often OpsAlertService evaluates alert rules.
//
// Production uses opsMetricsInterval. Tests may override this variable to keep
// integration tests fast without changing production defaults.
var opsAlertEvalInterval = opsMetricsInterval
func NewOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
return &OpsAlertService{
opsService: opsService,
userService: userService,
emailService: emailService,
httpClient: &http.Client{Timeout: 10 * time.Second},
interval: opsAlertEvalInterval,
}
}
// Start launches the background alert evaluation loop.
//
// Stop must be called during shutdown to ensure the goroutine exits.
func (s *OpsAlertService) Start() {
s.StartWithContext(context.Background())
}
// StartWithContext is like Start but allows the caller to provide a parent context.
// When the parent context is canceled, the service stops automatically.
func (s *OpsAlertService) StartWithContext(ctx context.Context) {
if s == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
s.startOnce.Do(func() {
if s.interval <= 0 {
s.interval = opsAlertEvalInterval
}
s.stopCtx, s.stop = context.WithCancel(ctx)
s.wg.Add(1)
go s.run()
})
}
// Stop gracefully stops the background goroutine started by Start/StartWithContext.
// It is safe to call Stop multiple times.
func (s *OpsAlertService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.stop != nil {
s.stop()
}
})
s.wg.Wait()
}
func (s *OpsAlertService) run() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.evaluateOnce()
for {
select {
case <-ticker.C:
s.evaluateOnce()
case <-s.stopCtx.Done():
return
}
}
}
func (s *OpsAlertService) evaluateOnce() {
ctx, cancel := context.WithTimeout(s.stopCtx, opsAlertEvaluateTimeout)
defer cancel()
s.Evaluate(ctx, time.Now())
}
func (s *OpsAlertService) Evaluate(ctx context.Context, now time.Time) {
if s == nil || s.opsService == nil {
return
}
rules, err := s.opsService.ListAlertRules(ctx)
if err != nil {
log.Printf("[OpsAlert] failed to list rules: %v", err)
return
}
if len(rules) == 0 {
return
}
maxSustainedByWindow := make(map[int]int)
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
if sustained > maxSustainedByWindow[window] {
maxSustainedByWindow[window] = sustained
}
}
metricsByWindow := make(map[int][]OpsMetrics)
for window, limit := range maxSustainedByWindow {
metrics, err := s.opsService.ListRecentSystemMetrics(ctx, window, limit)
if err != nil {
log.Printf("[OpsAlert] failed to load metrics window=%dm: %v", window, err)
continue
}
metricsByWindow[window] = metrics
}
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
metrics := metricsByWindow[window]
selected, ok := selectContiguousMetrics(metrics, sustained, now)
if !ok {
continue
}
breached, latestValue, ok := evaluateRule(rule, selected)
if !ok {
continue
}
activeEvent, err := s.opsService.GetActiveAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get active event (rule=%d): %v", rule.ID, err)
continue
}
if breached {
if activeEvent != nil {
continue
}
lastEvent, err := s.opsService.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get latest event (rule=%d): %v", rule.ID, err)
continue
}
if lastEvent != nil && rule.CooldownMinutes > 0 {
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
if now.Sub(lastEvent.FiredAt) < cooldown {
continue
}
}
event := &OpsAlertEvent{
RuleID: rule.ID,
Severity: rule.Severity,
Status: OpsAlertStatusFiring,
Title: fmt.Sprintf("%s: %s", rule.Severity, rule.Name),
Description: buildAlertDescription(rule, latestValue),
MetricValue: latestValue,
ThresholdValue: rule.Threshold,
FiredAt: now,
CreatedAt: now,
}
if err := s.opsService.CreateAlertEvent(ctx, event); err != nil {
log.Printf("[OpsAlert] failed to create event (rule=%d): %v", rule.ID, err)
continue
}
emailSent, webhookSent := s.dispatchNotifications(ctx, rule, event)
if emailSent || webhookSent {
if err := s.opsService.UpdateAlertEventNotifications(ctx, event.ID, emailSent, webhookSent); err != nil {
log.Printf("[OpsAlert] failed to update notification flags (event=%d): %v", event.ID, err)
}
}
} else if activeEvent != nil {
resolvedAt := now
if err := s.opsService.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
log.Printf("[OpsAlert] failed to resolve event (event=%d): %v", activeEvent.ID, err)
}
}
}
}
const opsMetricsContinuityTolerance = 20 * time.Second
// selectContiguousMetrics picks the newest N metrics and verifies they are continuous.
//
// This prevents a sustained rule from triggering when metrics sampling has gaps
// (e.g. collector downtime) and avoids evaluating "stale" data.
//
// Assumptions:
// - Metrics are ordered by UpdatedAt DESC (newest first).
// - Metrics are expected to be collected at opsMetricsInterval cadence.
func selectContiguousMetrics(metrics []OpsMetrics, needed int, now time.Time) ([]OpsMetrics, bool) {
if needed <= 0 {
return nil, false
}
if len(metrics) < needed {
return nil, false
}
newest := metrics[0].UpdatedAt
if newest.IsZero() {
return nil, false
}
if now.Sub(newest) > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
selected := metrics[:needed]
for i := 0; i < len(selected)-1; i++ {
a := selected[i].UpdatedAt
b := selected[i+1].UpdatedAt
if a.IsZero() || b.IsZero() {
return nil, false
}
gap := a.Sub(b)
if gap < opsMetricsInterval-opsMetricsContinuityTolerance || gap > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
}
return selected, true
}
func evaluateRule(rule OpsAlertRule, metrics []OpsMetrics) (bool, float64, bool) {
if len(metrics) == 0 {
return false, 0, false
}
latestValue, ok := metricValue(metrics[0], rule.MetricType)
if !ok {
return false, 0, false
}
for _, metric := range metrics {
value, ok := metricValue(metric, rule.MetricType)
if !ok || !compareMetric(value, rule.Operator, rule.Threshold) {
return false, latestValue, true
}
}
return true, latestValue, true
}
func metricValue(metric OpsMetrics, metricType string) (float64, bool) {
switch metricType {
case OpsMetricSuccessRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.SuccessRate, true
case OpsMetricErrorRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.ErrorRate, true
case OpsMetricP95LatencyMs:
return float64(metric.P95LatencyMs), true
case OpsMetricP99LatencyMs:
return float64(metric.P99LatencyMs), true
case OpsMetricHTTP2Errors:
return float64(metric.HTTP2Errors), true
case OpsMetricCPUUsagePercent:
return metric.CPUUsagePercent, true
case OpsMetricMemoryUsagePercent:
return metric.MemoryUsagePercent, true
case OpsMetricQueueDepth:
return float64(metric.ConcurrencyQueueDepth), true
default:
return 0, false
}
}
func compareMetric(value float64, operator string, threshold float64) bool {
switch operator {
case ">":
return value > threshold
case ">=":
return value >= threshold
case "<":
return value < threshold
case "<=":
return value <= threshold
case "==":
return value == threshold
default:
return false
}
}
func buildAlertDescription(rule OpsAlertRule, value float64) string {
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
return fmt.Sprintf("Rule %s triggered: %s %s %.2f (current %.2f) over last %dm",
rule.Name,
rule.MetricType,
rule.Operator,
rule.Threshold,
value,
window,
)
}
func (s *OpsAlertService) dispatchNotifications(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) (bool, bool) {
emailSent := false
webhookSent := false
notifyCtx, cancel := s.notificationContext(ctx)
defer cancel()
if rule.NotifyEmail {
emailSent = s.sendEmailNotification(notifyCtx, rule, event)
}
if rule.NotifyWebhook && rule.WebhookURL != "" {
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
// Fallback channel: if email is enabled but ultimately fails, try webhook even if the
// webhook toggle is off (as long as a webhook URL is configured).
if rule.NotifyEmail && !emailSent && !rule.NotifyWebhook && rule.WebhookURL != "" {
log.Printf("[OpsAlert] email failed; attempting webhook fallback (rule=%d)", rule.ID)
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
return emailSent, webhookSent
}
const (
opsAlertEvaluateTimeout = 45 * time.Second
opsAlertNotificationTimeout = 30 * time.Second
opsAlertEmailMaxRetries = 3
)
var opsAlertEmailBackoff = []time.Duration{
1 * time.Second,
2 * time.Second,
4 * time.Second,
}
func (s *OpsAlertService) notificationContext(ctx context.Context) (context.Context, context.CancelFunc) {
parent := ctx
if s != nil && s.stopCtx != nil {
parent = s.stopCtx
}
if parent == nil {
parent = context.Background()
}
return context.WithTimeout(parent, opsAlertNotificationTimeout)
}
var opsAlertSleep = sleepWithContext
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
if ctx == nil {
time.Sleep(d)
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func retryWithBackoff(
ctx context.Context,
maxRetries int,
backoff []time.Duration,
fn func() error,
onError func(attempt int, total int, nextDelay time.Duration, err error),
) error {
if ctx == nil {
ctx = context.Background()
}
if maxRetries < 0 {
maxRetries = 0
}
totalAttempts := maxRetries + 1
var lastErr error
for attempt := 1; attempt <= totalAttempts; attempt++ {
if attempt > 1 {
backoffIdx := attempt - 2
if backoffIdx < len(backoff) {
if err := opsAlertSleep(ctx, backoff[backoffIdx]); err != nil {
return err
}
}
}
if err := ctx.Err(); err != nil {
return err
}
if err := fn(); err != nil {
lastErr = err
nextDelay := time.Duration(0)
if attempt < totalAttempts {
nextIdx := attempt - 1
if nextIdx < len(backoff) {
nextDelay = backoff[nextIdx]
}
}
if onError != nil {
onError(attempt, totalAttempts, nextDelay, err)
}
continue
}
return nil
}
return lastErr
}
func (s *OpsAlertService) sendEmailNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
if s.emailService == nil || s.userService == nil {
return false
}
if ctx == nil {
ctx = context.Background()
}
admin, err := s.userService.GetFirstAdmin(ctx)
if err != nil || admin == nil || admin.Email == "" {
return false
}
subject := fmt.Sprintf("[Ops Alert][%s] %s", rule.Severity, rule.Name)
body := fmt.Sprintf(
"Alert triggered: %s\n\nMetric: %s\nThreshold: %.2f\nCurrent: %.2f\nWindow: %dm\nStatus: %s\nTime: %s",
rule.Name,
rule.MetricType,
rule.Threshold,
event.MetricValue,
rule.WindowMinutes,
event.Status,
event.FiredAt.Format(time.RFC3339),
)
config, err := s.emailService.GetSMTPConfig(ctx)
if err != nil {
log.Printf("[OpsAlert] email config load failed: %v", err)
return false
}
if err := retryWithBackoff(
ctx,
opsAlertEmailMaxRetries,
opsAlertEmailBackoff,
func() error {
return s.emailService.SendEmailWithConfig(config, admin.Email, subject, body)
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt < total {
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), retrying in %s: %v", attempt, total, nextDelay, err)
return
}
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), giving up: %v", attempt, total, err)
},
); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("[OpsAlert] email send canceled: %v", err)
}
return false
}
return true
}
func (s *OpsAlertService) sendWebhookNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
webhookTarget, err := validateWebhookURL(ctx, rule.WebhookURL)
if err != nil {
log.Printf("[OpsAlert] invalid webhook url (rule=%d): %v", rule.ID, err)
return false
}
payload := map[string]any{
"rule_id": rule.ID,
"rule_name": rule.Name,
"severity": rule.Severity,
"status": event.Status,
"metric_type": rule.MetricType,
"metric_value": event.MetricValue,
"threshold_value": rule.Threshold,
"window_minutes": rule.WindowMinutes,
"fired_at": event.FiredAt.Format(time.RFC3339),
}
body, err := json.Marshal(payload)
if err != nil {
return false
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookTarget.URL.String(), bytes.NewReader(body))
if err != nil {
return false
}
req.Header.Set("Content-Type", "application/json")
resp, err := buildWebhookHTTPClient(s.httpClient, webhookTarget).Do(req)
if err != nil {
log.Printf("[OpsAlert] webhook send failed: %v", err)
return false
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
log.Printf("[OpsAlert] webhook returned status %d", resp.StatusCode)
return false
}
return true
}
const webhookHTTPClientTimeout = 10 * time.Second
func buildWebhookHTTPClient(base *http.Client, webhookTarget *validatedWebhookTarget) *http.Client {
var client http.Client
if base != nil {
client = *base
}
if client.Timeout <= 0 {
client.Timeout = webhookHTTPClientTimeout
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
if webhookTarget != nil {
client.Transport = buildWebhookTransport(client.Transport, webhookTarget)
}
return &client
}
var disallowedWebhookIPNets = []net.IPNet{
// "this host on this network" / unspecified.
mustParseCIDR("0.0.0.0/8"),
mustParseCIDR("127.0.0.0/8"), // loopback (includes 127.0.0.1)
mustParseCIDR("10.0.0.0/8"), // RFC1918
mustParseCIDR("192.168.0.0/16"), // RFC1918
mustParseCIDR("172.16.0.0/12"), // RFC1918 (172.16.0.0 - 172.31.255.255)
mustParseCIDR("100.64.0.0/10"), // RFC6598 (carrier-grade NAT)
mustParseCIDR("169.254.0.0/16"), // IPv4 link-local (includes 169.254.169.254 metadata IP on many clouds)
mustParseCIDR("198.18.0.0/15"), // RFC2544 benchmark testing
mustParseCIDR("224.0.0.0/4"), // IPv4 multicast
mustParseCIDR("240.0.0.0/4"), // IPv4 reserved
mustParseCIDR("::/128"), // IPv6 unspecified
mustParseCIDR("::1/128"), // IPv6 loopback
mustParseCIDR("fc00::/7"), // IPv6 unique local
mustParseCIDR("fe80::/10"), // IPv6 link-local
mustParseCIDR("ff00::/8"), // IPv6 multicast
}
func mustParseCIDR(cidr string) net.IPNet {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
return *block
}
var lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
}
type validatedWebhookTarget struct {
URL *url.URL
host string
port string
pinnedIPs []net.IP
}
var webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
}
func buildWebhookTransport(base http.RoundTripper, webhookTarget *validatedWebhookTarget) http.RoundTripper {
if webhookTarget == nil || webhookTarget.URL == nil {
return base
}
var transport *http.Transport
switch typed := base.(type) {
case *http.Transport:
if typed != nil {
transport = typed.Clone()
}
}
if transport == nil {
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok && defaultTransport != nil {
transport = defaultTransport.Clone()
} else {
transport = (&http.Transport{}).Clone()
}
}
webhookHost := webhookTarget.host
webhookPort := webhookTarget.port
pinnedIPs := append([]net.IP(nil), webhookTarget.pinnedIPs...)
transport.Proxy = nil
transport.DialTLSContext = nil
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil || host == "" || port == "" {
return nil, fmt.Errorf("webhook dial target is invalid: %q", addr)
}
canonicalHost := strings.TrimSuffix(strings.ToLower(host), ".")
if canonicalHost != webhookHost || port != webhookPort {
return nil, fmt.Errorf("webhook dial target mismatch: %q", addr)
}
var lastErr error
for _, ip := range pinnedIPs {
if isDisallowedWebhookIP(ip) {
lastErr = fmt.Errorf("webhook target resolves to a disallowed ip")
continue
}
dialAddr := net.JoinHostPort(ip.String(), port)
conn, err := webhookBaseDialContext(ctx, network, dialAddr)
if err == nil {
return conn, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("webhook target has no resolved addresses")
}
return nil, lastErr
}
return transport
}
func validateWebhookURL(ctx context.Context, raw string) (*validatedWebhookTarget, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, errors.New("webhook url is empty")
}
// Avoid request smuggling / header injection vectors.
if strings.ContainsAny(raw, "\r\n") {
return nil, errors.New("webhook url contains invalid characters")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, errors.New("webhook url format is invalid")
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, errors.New("webhook url scheme must be https")
}
parsed.Scheme = "https"
if parsed.Host == "" || parsed.Hostname() == "" {
return nil, errors.New("webhook url must include host")
}
if parsed.User != nil {
return nil, errors.New("webhook url must not include userinfo")
}
if parsed.Port() != "" {
port, err := strconv.Atoi(parsed.Port())
if err != nil || port < 1 || port > 65535 {
return nil, errors.New("webhook url port is invalid")
}
}
host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
if host == "localhost" {
return nil, errors.New("webhook url host must not be localhost")
}
if ip := net.ParseIP(host); ip != nil {
if isDisallowedWebhookIP(ip) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: []net.IP{ip},
}, nil
}
if ctx == nil {
ctx = context.Background()
}
ips, err := lookupIPAddrs(ctx, host)
if err != nil || len(ips) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
pinned := make([]net.IP, 0, len(ips))
for _, addr := range ips {
if isDisallowedWebhookIP(addr.IP) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
if addr.IP != nil {
pinned = append(pinned, addr.IP)
}
}
if len(pinned) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: uniqueResolvedIPs(pinned),
}, nil
}
func isDisallowedWebhookIP(ip net.IP) bool {
if ip == nil {
return false
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
} else if ip16 := ip.To16(); ip16 != nil {
ip = ip16
} else {
return false
}
// Disallow non-public addresses even if they're not explicitly covered by the CIDR list.
// This provides defense-in-depth against SSRF targets such as link-local, multicast, and
// unspecified addresses, and ensures any "pinned" IP is still blocked at dial time.
if ip.IsUnspecified() ||
ip.IsLoopback() ||
ip.IsMulticast() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsPrivate() {
return true
}
for _, block := range disallowedWebhookIPNets {
if block.Contains(ip) {
return true
}
}
return false
}
func portForScheme(u *url.URL) string {
if u != nil && u.Port() != "" {
return u.Port()
}
return "443"
}
func uniqueResolvedIPs(ips []net.IP) []net.IP {
seen := make(map[string]struct{}, len(ips))
out := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if ip == nil {
continue
}
key := ip.String()
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, ip)
}
return out
}

View File

@@ -0,0 +1,271 @@
//go:build integration
package service
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// This integration test protects the DI startup contract for OpsAlertService.
//
// Background:
// - OpsMetricsCollector previously called alertService.Start()/Evaluate() directly.
// - Those direct calls were removed, so OpsAlertService must now start via DI
// (ProvideOpsAlertService in wire.go) and run its own evaluation ticker.
//
// What we validate here:
// 1. When we construct via the Wire provider functions (ProvideOpsAlertService +
// ProvideOpsMetricsCollector), OpsAlertService starts automatically.
// 2. Its evaluation loop continues to tick even if OpsMetricsCollector is stopped,
// proving the alert evaluator is independent.
// 3. The evaluation path can trigger alert logic (CreateAlertEvent called).
func TestOpsAlertService_StartedViaWireProviders_RunsIndependentTicker(t *testing.T) {
oldInterval := opsAlertEvalInterval
opsAlertEvalInterval = 25 * time.Millisecond
t.Cleanup(func() { opsAlertEvalInterval = oldInterval })
repo := newFakeOpsRepository()
opsService := NewOpsService(repo, nil)
// Start via the Wire provider function (the production DI path).
alertService := ProvideOpsAlertService(opsService, nil, nil)
t.Cleanup(alertService.Stop)
// Construct via ProvideOpsMetricsCollector (wire.go). Stop immediately to ensure
// the alert ticker keeps running without the metrics collector.
collector := ProvideOpsMetricsCollector(opsService, NewConcurrencyService(nil))
collector.Stop()
// Wait for at least one evaluation (run() calls evaluateOnce immediately).
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= 1
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation loop keeps ticking after the metrics collector is stopped.
callsAfterCollectorStop := repo.listRulesCalls.Load()
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= callsAfterCollectorStop+2
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation logic actually fires an alert event at least once.
select {
case <-repo.eventCreatedCh:
// ok
case <-time.After(2 * time.Second):
t.Fatalf("expected OpsAlertService to create an alert event, but none was created (ListAlertRules calls=%d)", repo.listRulesCalls.Load())
}
}
func newFakeOpsRepository() *fakeOpsRepository {
return &fakeOpsRepository{
eventCreatedCh: make(chan struct{}),
}
}
// fakeOpsRepository is a lightweight in-memory stub of OpsRepository for integration tests.
// It avoids real DB/Redis usage and provides deterministic responses fast.
type fakeOpsRepository struct {
listRulesCalls atomic.Int64
mu sync.Mutex
activeEvent *OpsAlertEvent
latestEvent *OpsAlertEvent
nextEventID int64
eventCreatedCh chan struct{}
eventOnce sync.Once
}
func (r *fakeOpsRepository) CreateErrorLog(ctx context.Context, log *OpsErrorLog) error {
return nil
}
func (r *fakeOpsRepository) ListErrorLogsLegacy(ctx context.Context, filters OpsErrorLogFilters) ([]OpsErrorLog, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListErrorLogs(ctx context.Context, filter *ErrorLogFilter) ([]*ErrorLog, int64, error) {
return nil, 0, nil
}
func (r *fakeOpsRepository) GetLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return &OpsMetrics{WindowMinutes: 1}, sql.ErrNoRows
}
func (r *fakeOpsRepository) CreateSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
return &OpsWindowStats{}, nil
}
func (r *fakeOpsRepository) GetProviderStats(ctx context.Context, startTime, endTime time.Time) ([]*ProviderStats, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetLatencyHistogram(ctx context.Context, startTime, endTime time.Time) ([]*LatencyHistogramItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetErrorDistribution(ctx context.Context, startTime, endTime time.Time) ([]*ErrorDistributionItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
if limit <= 0 {
limit = 1
}
now := time.Now()
metrics := make([]OpsMetrics, 0, limit)
for i := 0; i < limit; i++ {
metrics = append(metrics, OpsMetrics{
WindowMinutes: windowMinutes,
CPUUsagePercent: 99,
UpdatedAt: now.Add(-time.Duration(i) * opsMetricsInterval),
})
}
return metrics, nil
}
func (r *fakeOpsRepository) ListSystemMetricsRange(ctx context.Context, windowMinutes int, startTime, endTime time.Time, limit int) ([]OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
call := r.listRulesCalls.Add(1)
// Delay enabling rules slightly so the test can stop OpsMetricsCollector first,
// then observe the alert evaluator ticking independently.
if call < 5 {
return nil, nil
}
return []OpsAlertRule{
{
ID: 1,
Name: "cpu too high (test)",
Enabled: true,
MetricType: OpsMetricCPUUsagePercent,
Operator: ">",
Threshold: 0,
WindowMinutes: 1,
SustainedMinutes: 1,
Severity: "P1",
NotifyEmail: false,
NotifyWebhook: false,
CooldownMinutes: 0,
},
}, nil
}
func (r *fakeOpsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return nil, nil
}
if r.activeEvent.RuleID != ruleID {
return nil, nil
}
if r.activeEvent.Status != OpsAlertStatusFiring {
return nil, nil
}
clone := *r.activeEvent
return &clone, nil
}
func (r *fakeOpsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.latestEvent == nil || r.latestEvent.RuleID != ruleID {
return nil, nil
}
clone := *r.latestEvent
return &clone, nil
}
func (r *fakeOpsRepository) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
if event == nil {
return nil
}
r.mu.Lock()
defer r.mu.Unlock()
r.nextEventID++
event.ID = r.nextEventID
clone := *event
r.latestEvent = &clone
if clone.Status == OpsAlertStatusFiring {
r.activeEvent = &clone
}
r.eventOnce.Do(func() { close(r.eventCreatedCh) })
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.Status = status
r.activeEvent.ResolvedAt = resolvedAt
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.Status = status
r.latestEvent.ResolvedAt = resolvedAt
}
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.EmailSent = emailSent
r.activeEvent.WebhookSent = webhookSent
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.EmailSent = emailSent
r.latestEvent.WebhookSent = webhookSent
}
return nil
}
func (r *fakeOpsRepository) CountActiveAlerts(ctx context.Context) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return 0, nil
}
return 1, nil
}
func (r *fakeOpsRepository) GetOverviewStats(ctx context.Context, startTime, endTime time.Time) (*OverviewStats, error) {
return &OverviewStats{}, nil
}
func (r *fakeOpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*DashboardOverviewData, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *DashboardOverviewData, ttl time.Duration) error {
return nil
}
func (r *fakeOpsRepository) PingRedis(ctx context.Context) error {
return nil
}

View File

@@ -0,0 +1,315 @@
//go:build unit || opsalert_unit
package service
import (
"context"
"errors"
"net"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSelectContiguousMetrics_Contiguous(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
{UpdatedAt: now.Add(-1 * time.Minute)},
{UpdatedAt: now.Add(-2 * time.Minute)},
}
selected, ok := selectContiguousMetrics(metrics, 3, now)
require.True(t, ok)
require.Len(t, selected, 3)
}
func TestSelectContiguousMetrics_GapFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
// Missing the -1m sample (gap ~=2m).
{UpdatedAt: now.Add(-2 * time.Minute)},
{UpdatedAt: now.Add(-3 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 3, now)
require.False(t, ok)
}
func TestSelectContiguousMetrics_StaleNewestFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 10, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now.Add(-10 * time.Minute)},
{UpdatedAt: now.Add(-11 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 2, now)
require.False(t, ok)
}
func TestMetricValue_SuccessRate_NoTrafficIsNoData(t *testing.T) {
metric := OpsMetrics{
RequestCount: 0,
SuccessRate: 0,
}
value, ok := metricValue(metric, OpsMetricSuccessRate)
require.False(t, ok)
require.Equal(t, 0.0, value)
}
func TestOpsAlertService_StopWithoutStart_NoPanic(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
require.NotPanics(t, func() { s.Stop() })
}
func TestOpsAlertService_StartStop_Graceful(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
s.interval = 5 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.StartWithContext(ctx)
done := make(chan struct{})
go func() {
s.Stop()
close(done)
}()
select {
case <-done:
// ok
case <-time.After(1 * time.Second):
t.Fatal("Stop did not return; background goroutine likely stuck")
}
require.NotPanics(t, func() { s.Stop() })
}
func TestBuildWebhookHTTPClient_DefaultTimeout(t *testing.T) {
client := buildWebhookHTTPClient(nil, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
require.ErrorIs(t, client.CheckRedirect(nil, nil), http.ErrUseLastResponse)
base := &http.Client{}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
base = &http.Client{Timeout: 2 * time.Second}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, 2*time.Second, client.Timeout)
require.NotNil(t, client.CheckRedirect)
}
func TestValidateWebhookURL_RequiresHTTPS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "http://example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_InvalidFormatRejected(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://[::1")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsUserinfo(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://user:pass@example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLocalhost(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://localhost/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsPrivateIPLiteral(t *testing.T) {
cases := []string{
"https://0.0.0.0/webhook",
"https://127.0.0.1/webhook",
"https://10.0.0.1/webhook",
"https://192.168.1.2/webhook",
"https://172.16.0.1/webhook",
"https://172.31.255.255/webhook",
"https://100.64.0.1/webhook",
"https://169.254.169.254/webhook",
"https://198.18.0.1/webhook",
"https://224.0.0.1/webhook",
"https://240.0.0.1/webhook",
"https://[::]/webhook",
"https://[::1]/webhook",
"https://[ff02::1]/webhook",
}
for _, tc := range cases {
t.Run(tc, func(t *testing.T) {
_, err := validateWebhookURL(context.Background(), tc)
require.Error(t, err)
})
}
}
func TestValidateWebhookURL_RejectsPrivateIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "internal.example", host)
return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://internal.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLinkLocalIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "metadata.example", host)
return []net.IPAddr{{IP: net.ParseIP("169.254.169.254")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://metadata.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_AllowsPublicHostViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com:443/webhook")
require.NoError(t, err)
require.Equal(t, "https", target.URL.Scheme)
require.Equal(t, "example.com", target.URL.Hostname())
require.Equal(t, "443", target.URL.Port())
}
func TestValidateWebhookURL_RejectsInvalidPort(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://example.com:99999/webhook")
require.Error(t, err)
}
func TestWebhookTransport_UsesPinnedIP_NoDNSRebinding(t *testing.T) {
oldLookup := lookupIPAddrs
oldDial := webhookBaseDialContext
t.Cleanup(func() {
lookupIPAddrs = oldLookup
webhookBaseDialContext = oldDial
})
lookupCalls := 0
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com/webhook")
require.NoError(t, err)
require.Equal(t, 1, lookupCalls)
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
return []net.IPAddr{{IP: net.ParseIP("10.0.0.1")}}, nil
}
var dialAddrs []string
webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialAddrs = append(dialAddrs, addr)
return nil, errors.New("dial blocked in test")
}
client := buildWebhookHTTPClient(nil, target)
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
_, err = transport.DialContext(context.Background(), "tcp", "example.com:443")
require.Error(t, err)
require.Equal(t, []string{"93.184.216.34:443"}, dialAddrs)
require.Equal(t, 1, lookupCalls, "dial path must not re-resolve DNS")
}
func TestRetryWithBackoff_SucceedsAfterRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return nil
}
attempts := 0
err := retryWithBackoff(
context.Background(),
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
if attempts <= 3 {
return errors.New("send failed")
}
return nil
},
nil,
)
require.NoError(t, err)
require.Equal(t, 4, attempts)
require.Equal(t, []time.Duration{time.Second, 2 * time.Second, 4 * time.Second}, slept)
}
func TestRetryWithBackoff_ContextCanceledStopsRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return ctx.Err()
}
ctx, cancel := context.WithCancel(context.Background())
attempts := 0
err := retryWithBackoff(
ctx,
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
return errors.New("send failed")
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt == 1 {
cancel()
}
},
)
require.ErrorIs(t, err, context.Canceled)
require.Equal(t, 1, attempts)
require.Equal(t, []time.Duration{time.Second}, slept)
}

View File

@@ -0,0 +1,92 @@
package service
import (
"context"
"time"
)
const (
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
)
const (
OpsMetricSuccessRate = "success_rate"
OpsMetricErrorRate = "error_rate"
OpsMetricP95LatencyMs = "p95_latency_ms"
OpsMetricP99LatencyMs = "p99_latency_ms"
OpsMetricHTTP2Errors = "http2_errors"
OpsMetricCPUUsagePercent = "cpu_usage_percent"
OpsMetricMemoryUsagePercent = "memory_usage_percent"
OpsMetricQueueDepth = "concurrency_queue_depth"
)
type OpsAlertRule struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
MetricType string `json:"metric_type"`
Operator string `json:"operator"`
Threshold float64 `json:"threshold"`
WindowMinutes int `json:"window_minutes"`
SustainedMinutes int `json:"sustained_minutes"`
Severity string `json:"severity"`
NotifyEmail bool `json:"notify_email"`
NotifyWebhook bool `json:"notify_webhook"`
WebhookURL string `json:"webhook_url"`
CooldownMinutes int `json:"cooldown_minutes"`
DimensionFilters map[string]any `json:"dimension_filters,omitempty"`
NotifyChannels []string `json:"notify_channels,omitempty"`
NotifyConfig map[string]any `json:"notify_config,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type OpsAlertEvent struct {
ID int64 `json:"id"`
RuleID int64 `json:"rule_id"`
Severity string `json:"severity"`
Status string `json:"status"`
Title string `json:"title"`
Description string `json:"description"`
MetricValue float64 `json:"metric_value"`
ThresholdValue float64 `json:"threshold_value"`
FiredAt time.Time `json:"fired_at"`
ResolvedAt *time.Time `json:"resolved_at"`
EmailSent bool `json:"email_sent"`
WebhookSent bool `json:"webhook_sent"`
CreatedAt time.Time `json:"created_at"`
}
func (s *OpsService) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
return s.repo.ListAlertRules(ctx)
}
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
return s.repo.GetActiveAlertEvent(ctx, ruleID)
}
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
return s.repo.GetLatestAlertEvent(ctx, ruleID)
}
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
return s.repo.CreateAlertEvent(ctx, event)
}
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
return s.repo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
}
func (s *OpsService) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
return s.repo.UpdateAlertEventNotifications(ctx, eventID, emailSent, webhookSent)
}
func (s *OpsService) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
return s.repo.ListRecentSystemMetrics(ctx, windowMinutes, limit)
}
func (s *OpsService) CountActiveAlerts(ctx context.Context) (int, error) {
return s.repo.CountActiveAlerts(ctx)
}

View File

@@ -0,0 +1,203 @@
package service
import (
"context"
"log"
"runtime"
"sync"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
)
const (
opsMetricsInterval = 1 * time.Minute
opsMetricsCollectTimeout = 10 * time.Second
opsMetricsWindowShortMinutes = 1
opsMetricsWindowLongMinutes = 5
bytesPerMB = 1024 * 1024
cpuUsageSampleInterval = 0 * time.Second
percentScale = 100
)
type OpsMetricsCollector struct {
opsService *OpsService
concurrencyService *ConcurrencyService
interval time.Duration
lastGCPauseTotal uint64
lastGCPauseMu sync.Mutex
stopCh chan struct{}
startOnce sync.Once
stopOnce sync.Once
}
func NewOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
return &OpsMetricsCollector{
opsService: opsService,
concurrencyService: concurrencyService,
interval: opsMetricsInterval,
}
}
func (c *OpsMetricsCollector) Start() {
if c == nil {
return
}
c.startOnce.Do(func() {
if c.stopCh == nil {
c.stopCh = make(chan struct{})
}
go c.run()
})
}
func (c *OpsMetricsCollector) Stop() {
if c == nil {
return
}
c.stopOnce.Do(func() {
if c.stopCh != nil {
close(c.stopCh)
}
})
}
func (c *OpsMetricsCollector) run() {
ticker := time.NewTicker(c.interval)
defer ticker.Stop()
c.collectOnce()
for {
select {
case <-ticker.C:
c.collectOnce()
case <-c.stopCh:
return
}
}
}
func (c *OpsMetricsCollector) collectOnce() {
if c.opsService == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectTimeout)
defer cancel()
now := time.Now()
systemStats := c.collectSystemStats(ctx)
queueDepth := c.collectQueueDepth(ctx)
activeAlerts := c.collectActiveAlerts(ctx)
for _, window := range []int{opsMetricsWindowShortMinutes, opsMetricsWindowLongMinutes} {
startTime := now.Add(-time.Duration(window) * time.Minute)
windowStats, err := c.opsService.GetWindowStats(ctx, startTime, now)
if err != nil {
log.Printf("[OpsMetrics] failed to get window stats (%dm): %v", window, err)
continue
}
successRate, errorRate := computeRates(windowStats.SuccessCount, windowStats.ErrorCount)
requestCount := windowStats.SuccessCount + windowStats.ErrorCount
metric := &OpsMetrics{
WindowMinutes: window,
RequestCount: requestCount,
SuccessCount: windowStats.SuccessCount,
ErrorCount: windowStats.ErrorCount,
SuccessRate: successRate,
ErrorRate: errorRate,
P95LatencyMs: windowStats.P95LatencyMs,
P99LatencyMs: windowStats.P99LatencyMs,
HTTP2Errors: windowStats.HTTP2Errors,
ActiveAlerts: activeAlerts,
CPUUsagePercent: systemStats.cpuUsage,
MemoryUsedMB: systemStats.memoryUsedMB,
MemoryTotalMB: systemStats.memoryTotalMB,
MemoryUsagePercent: systemStats.memoryUsagePercent,
HeapAllocMB: systemStats.heapAllocMB,
GCPauseMs: systemStats.gcPauseMs,
ConcurrencyQueueDepth: queueDepth,
UpdatedAt: now,
}
if err := c.opsService.RecordMetrics(ctx, metric); err != nil {
log.Printf("[OpsMetrics] failed to record metrics (%dm): %v", window, err)
}
}
}
func computeRates(successCount, errorCount int64) (float64, float64) {
total := successCount + errorCount
if total == 0 {
// No traffic => no data. Rates are kept at 0 and request_count will be 0.
// The UI should render this as N/A instead of "100% success".
return 0, 0
}
successRate := float64(successCount) / float64(total) * percentScale
errorRate := float64(errorCount) / float64(total) * percentScale
return successRate, errorRate
}
type opsSystemStats struct {
cpuUsage float64
memoryUsedMB int64
memoryTotalMB int64
memoryUsagePercent float64
heapAllocMB int64
gcPauseMs float64
}
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) opsSystemStats {
stats := opsSystemStats{}
if percents, err := cpu.PercentWithContext(ctx, cpuUsageSampleInterval, false); err == nil && len(percents) > 0 {
stats.cpuUsage = percents[0]
}
if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil {
stats.memoryUsedMB = int64(vm.Used / bytesPerMB)
stats.memoryTotalMB = int64(vm.Total / bytesPerMB)
stats.memoryUsagePercent = vm.UsedPercent
}
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
stats.heapAllocMB = int64(memStats.HeapAlloc / bytesPerMB)
c.lastGCPauseMu.Lock()
if c.lastGCPauseTotal != 0 && memStats.PauseTotalNs >= c.lastGCPauseTotal {
stats.gcPauseMs = float64(memStats.PauseTotalNs-c.lastGCPauseTotal) / float64(time.Millisecond)
}
c.lastGCPauseTotal = memStats.PauseTotalNs
c.lastGCPauseMu.Unlock()
return stats
}
func (c *OpsMetricsCollector) collectQueueDepth(ctx context.Context) int {
if c.concurrencyService == nil {
return 0
}
depth, err := c.concurrencyService.GetTotalWaitCount(ctx)
if err != nil {
log.Printf("[OpsMetrics] failed to get queue depth: %v", err)
return 0
}
return depth
}
func (c *OpsMetricsCollector) collectActiveAlerts(ctx context.Context) int {
if c.opsService == nil {
return 0
}
count, err := c.opsService.CountActiveAlerts(ctx)
if err != nil {
return 0
}
return count
}

File diff suppressed because it is too large Load Diff

View File

@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyAPIBaseURL,
SettingKeyContactInfo,
SettingKeyDocUrl,
SettingKeyDocURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}, nil
}
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
}
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,9 +115,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeyDocURL] = settings.DocURL
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -198,8 +198,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +210,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
} else {
result.SmtpPort = 587
result.SMTPPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -245,8 +245,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result
@@ -278,28 +278,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -320,10 +320,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// GetAdminAPIKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -333,7 +333,7 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
}

View File

@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
@@ -19,9 +19,9 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
DefaultConcurrency int
DefaultBalance float64
@@ -35,8 +35,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
Version string
}

View File

@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic,
accType: AccountTypeApiKey,
accType: AccountTypeAPIKey,
want: false,
},
{

View File

@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
}
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
DownloadURL: a.BrowserDownloadURL,
Size: a.Size,
}
}
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
HTMLURL: release.HTMLURL,
Assets: assets,
},
Cached: false,

View File

@@ -10,7 +10,7 @@ const (
type UsageLog struct {
ID int64
UserID int64
ApiKeyID int64
APIKeyID int64
AccountID int64
RequestID string
Model string
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time
User *User
ApiKey *ApiKey
APIKey *APIKey
Account *Account
Group *Group
Subscription *UserSubscription

View File

@@ -17,7 +17,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog := &UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
// ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
// GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err)
}
@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil
}
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
ApiKeys []ApiKey
APIKeys []APIKey
Subscriptions []UserSubscription
}

View File

@@ -73,6 +73,20 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc
}
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
svc := NewOpsMetricsCollector(opsService, concurrencyService)
svc.Start()
return svc
}
// ProvideOpsAlertService creates and starts OpsAlertService.
func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
svc := NewOpsAlertService(opsService, userService, emailService)
svc.Start()
return svc
}
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
@@ -87,13 +101,14 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewApiKeyService,
NewAPIKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
NewRedeemService,
NewUsageService,
NewDashboardService,
NewOpsService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
@@ -125,5 +140,7 @@ var ProviderSet = wire.NewSet(
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
ProvideOpsMetricsCollector,
ProvideOpsAlertService,
NewUserAttributeService,
)