运维监控系统安全加固和功能优化 (#21)

* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
This commit is contained in:
IanShaw
2026-01-02 20:01:12 +08:00
committed by GitHub
parent 7fdc2b2d29
commit 45bd9ac705
171 changed files with 10618 additions and 2965 deletions

View File

@@ -1,3 +1,4 @@
// Package config provides application configuration management.
package config
import (
@@ -139,7 +140,7 @@ type GatewayConfig struct {
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
@@ -241,7 +242,7 @@ type DefaultConfig struct {
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
APIKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}

View File

@@ -1,3 +1,4 @@
// Package config provides application configuration management.
package config
import "github.com/google/wire"

View File

@@ -1,3 +1,5 @@
// Package admin provides HTTP handlers for administrative operations including
// dashboard statistics, user management, API key management, and account management.
package admin
import (
@@ -75,8 +77,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalApiKeys,
"active_api_keys": stats.ActiveApiKeys,
"total_api_keys": stats.TotalAPIKeys,
"active_api_keys": stats.ActiveAPIKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
@@ -193,10 +195,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetApiKeyUsageTrend handles getting API key usage trend data
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
@@ -205,7 +207,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
@@ -273,26 +275,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
response.Success(c, gin.H{"stats": stats})
}
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
var req BatchApiKeysUsageRequest
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
var req BatchAPIKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
if len(req.APIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -18,6 +18,7 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GetCapabilities retrieves OAuth configuration capabilities.
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()

View File

@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return
}
outKeys := make([]dto.ApiKey, 0, len(keys))
outKeys := make([]dto.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -0,0 +1,402 @@
package admin
import (
"math"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpsHandler handles ops dashboard endpoints.
type OpsHandler struct {
opsService *service.OpsService
}
// NewOpsHandler creates a new OpsHandler.
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
// GetMetrics returns the latest ops metrics snapshot.
// GET /api/v1/admin/ops/metrics
func (h *OpsHandler) GetMetrics(c *gin.Context) {
metrics, err := h.opsService.GetLatestMetrics(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get ops metrics")
return
}
response.Success(c, metrics)
}
// ListMetricsHistory returns a time-range slice of metrics for charts.
// GET /api/v1/admin/ops/metrics/history
//
// Query params:
// - window_minutes: int (default 1)
// - minutes: int (lookback; optional)
// - start_time/end_time: RFC3339 timestamps (optional; overrides minutes when provided)
// - limit: int (optional; max 100, default 300 for backward compatibility)
func (h *OpsHandler) ListMetricsHistory(c *gin.Context) {
windowMinutes := 1
if v := c.Query("window_minutes"); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
windowMinutes = parsed
} else {
response.BadRequest(c, "Invalid window_minutes")
return
}
}
limit := 300
limitProvided := false
if v := c.Query("limit"); v != "" {
parsed, err := strconv.Atoi(v)
if err != nil || parsed <= 0 || parsed > 5000 {
response.BadRequest(c, "Invalid limit (must be 1-5000)")
return
}
limit = parsed
limitProvided = true
}
endTime := time.Now()
startTime := time.Time{}
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
parsed, err := time.Parse(time.RFC3339, startTimeStr)
if err != nil {
response.BadRequest(c, "Invalid start_time format (RFC3339)")
return
}
startTime = parsed
}
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
parsed, err := time.Parse(time.RFC3339, endTimeStr)
if err != nil {
response.BadRequest(c, "Invalid end_time format (RFC3339)")
return
}
endTime = parsed
}
// If explicit range not provided, use lookback minutes.
if startTime.IsZero() {
if v := c.Query("minutes"); v != "" {
minutes, err := strconv.Atoi(v)
if err != nil || minutes <= 0 {
response.BadRequest(c, "Invalid minutes")
return
}
if minutes > 60*24*7 {
minutes = 60 * 24 * 7
}
startTime = endTime.Add(-time.Duration(minutes) * time.Minute)
}
}
// Default time range: last 24 hours.
if startTime.IsZero() {
startTime = endTime.Add(-24 * time.Hour)
if !limitProvided {
// Metrics are collected at 1-minute cadence; 24h requires ~1440 points.
limit = 24 * 60
}
}
if startTime.After(endTime) {
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
return
}
items, err := h.opsService.ListMetricsHistory(c.Request.Context(), windowMinutes, startTime, endTime, limit)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to list ops metrics history")
return
}
response.Success(c, gin.H{"items": items})
}
// ListErrorLogs lists recent error logs with optional filters.
// GET /api/v1/admin/ops/error-logs
//
// Query params:
// - start_time/end_time: RFC3339 timestamps (optional)
// - platform: string (optional)
// - phase: string (optional)
// - severity: string (optional)
// - q: string (optional; fuzzy match)
// - limit: int (optional; default 100; max 500)
func (h *OpsHandler) ListErrorLogs(c *gin.Context) {
var filters service.OpsErrorLogFilters
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
startTime, err := time.Parse(time.RFC3339, startTimeStr)
if err != nil {
response.BadRequest(c, "Invalid start_time format (RFC3339)")
return
}
filters.StartTime = &startTime
}
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
endTime, err := time.Parse(time.RFC3339, endTimeStr)
if err != nil {
response.BadRequest(c, "Invalid end_time format (RFC3339)")
return
}
filters.EndTime = &endTime
}
if filters.StartTime != nil && filters.EndTime != nil && filters.StartTime.After(*filters.EndTime) {
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
return
}
filters.Platform = c.Query("platform")
filters.Phase = c.Query("phase")
filters.Severity = c.Query("severity")
filters.Query = c.Query("q")
filters.Limit = 100
if limitStr := c.Query("limit"); limitStr != "" {
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 || limit > 500 {
response.BadRequest(c, "Invalid limit (must be 1-500)")
return
}
filters.Limit = limit
}
items, total, err := h.opsService.ListErrorLogs(c.Request.Context(), filters)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to list error logs")
return
}
response.Success(c, gin.H{
"items": items,
"total": total,
})
}
// GetDashboardOverview returns realtime ops dashboard overview.
// GET /api/v1/admin/ops/dashboard/overview
//
// Query params:
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
timeRange := c.Query("time_range")
if timeRange == "" {
timeRange = "1h"
}
switch timeRange {
case "5m", "30m", "1h", "6h", "24h":
default:
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
return
}
data, err := h.opsService.GetDashboardOverview(c.Request.Context(), timeRange)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get dashboard overview")
return
}
response.Success(c, data)
}
// GetProviderHealth returns upstream provider health comparison data.
// GET /api/v1/admin/ops/dashboard/providers
//
// Query params:
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
func (h *OpsHandler) GetProviderHealth(c *gin.Context) {
timeRange := c.Query("time_range")
if timeRange == "" {
timeRange = "1h"
}
switch timeRange {
case "5m", "30m", "1h", "6h", "24h":
default:
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
return
}
providers, err := h.opsService.GetProviderHealth(c.Request.Context(), timeRange)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get provider health")
return
}
var totalRequests int64
var weightedSuccess float64
var bestProvider string
var worstProvider string
var bestRate float64
var worstRate float64
hasRate := false
for _, p := range providers {
if p == nil {
continue
}
totalRequests += p.RequestCount
weightedSuccess += (p.SuccessRate / 100) * float64(p.RequestCount)
if p.RequestCount <= 0 {
continue
}
if !hasRate {
bestProvider = p.Name
worstProvider = p.Name
bestRate = p.SuccessRate
worstRate = p.SuccessRate
hasRate = true
continue
}
if p.SuccessRate > bestRate {
bestProvider = p.Name
bestRate = p.SuccessRate
}
if p.SuccessRate < worstRate {
worstProvider = p.Name
worstRate = p.SuccessRate
}
}
avgSuccessRate := 0.0
if totalRequests > 0 {
avgSuccessRate = (weightedSuccess / float64(totalRequests)) * 100
avgSuccessRate = math.Round(avgSuccessRate*100) / 100
}
response.Success(c, gin.H{
"providers": providers,
"summary": gin.H{
"total_requests": totalRequests,
"avg_success_rate": avgSuccessRate,
"best_provider": bestProvider,
"worst_provider": worstProvider,
},
})
}
// GetErrorLogs returns a paginated error log list with multi-dimensional filters.
// GET /api/v1/admin/ops/errors
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
filter := &service.ErrorLogFilter{
Page: page,
PageSize: pageSize,
}
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
startTime, err := time.Parse(time.RFC3339, startTimeStr)
if err != nil {
response.BadRequest(c, "Invalid start_time format (RFC3339)")
return
}
filter.StartTime = &startTime
}
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
endTime, err := time.Parse(time.RFC3339, endTimeStr)
if err != nil {
response.BadRequest(c, "Invalid end_time format (RFC3339)")
return
}
filter.EndTime = &endTime
}
if filter.StartTime != nil && filter.EndTime != nil && filter.StartTime.After(*filter.EndTime) {
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
return
}
if errorCodeStr := c.Query("error_code"); errorCodeStr != "" {
code, err := strconv.Atoi(errorCodeStr)
if err != nil || code < 0 {
response.BadRequest(c, "Invalid error_code")
return
}
filter.ErrorCode = &code
}
// Keep both parameter names for compatibility: provider (docs) and platform (legacy).
filter.Provider = c.Query("provider")
if filter.Provider == "" {
filter.Provider = c.Query("platform")
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
accountID, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil || accountID <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &accountID
}
out, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get error logs")
return
}
response.Success(c, gin.H{
"errors": out.Errors,
"total": out.Total,
"page": out.Page,
"page_size": out.PageSize,
})
}
// GetLatencyHistogram returns the latency distribution histogram.
// GET /api/v1/admin/ops/dashboard/latency-histogram
func (h *OpsHandler) GetLatencyHistogram(c *gin.Context) {
timeRange := c.Query("time_range")
if timeRange == "" {
timeRange = "1h"
}
buckets, err := h.opsService.GetLatencyHistogram(c.Request.Context(), timeRange)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get latency histogram")
return
}
totalRequests := int64(0)
for _, b := range buckets {
totalRequests += b.Count
}
response.Success(c, gin.H{
"buckets": buckets,
"total_requests": totalRequests,
"slow_request_threshold": 1000,
})
}
// GetErrorDistribution returns the error distribution.
// GET /api/v1/admin/ops/dashboard/errors/distribution
func (h *OpsHandler) GetErrorDistribution(c *gin.Context) {
timeRange := c.Query("time_range")
if timeRange == "" {
timeRange = "1h"
}
items, err := h.opsService.GetErrorDistribution(c.Request.Context(), timeRange)
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get error distribution")
return
}
response.Success(c, gin.H{
"items": items,
})
}

View File

@@ -0,0 +1,286 @@
package admin
import (
"context"
"encoding/json"
"log"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type OpsWSProxyConfig struct {
TrustProxy bool
TrustedProxies []netip.Prefix
OriginPolicy string
}
const (
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
)
const (
OriginPolicyStrict = "strict"
OriginPolicyPermissive = "permissive"
)
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return isAllowedOpsWSOrigin(r)
},
}
// QPSWSHandler handles realtime QPS push via WebSocket.
// GET /api/v1/admin/ops/ws/qps
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
return
}
defer func() { _ = conn.Close() }()
// Set pong handler
if err := conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(60 * time.Second))
})
// Push QPS data every 2 seconds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
// Heartbeat ping every 30 seconds
pingTicker := time.NewTicker(30 * time.Second)
defer pingTicker.Stop()
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
for {
select {
case <-ticker.C:
// Fetch 1m window stats for current QPS
data, err := h.opsService.GetDashboardOverview(ctx, "5m")
if err != nil {
log.Printf("[OpsWS] get overview failed: %v", err)
continue
}
payload := gin.H{
"type": "qps_update",
"timestamp": time.Now().Format(time.RFC3339),
"data": gin.H{
"qps": data.QPS.Current,
"tps": data.TPS.Current,
"request_count": data.Errors.TotalCount + int64(data.QPS.Avg1h*60), // Rough estimate
},
}
msg, _ := json.Marshal(payload)
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
return
}
case <-pingTicker.C:
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
return
}
case <-ctx.Done():
return
}
}
}
func isAllowedOpsWSOrigin(r *http.Request) bool {
if r == nil {
return false
}
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
case OriginPolicyStrict:
return false
case OriginPolicyPermissive, "":
return true
default:
return true
}
}
parsed, err := url.Parse(origin)
if err != nil || parsed.Hostname() == "" {
return false
}
originHost := strings.ToLower(parsed.Hostname())
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
reqHost := hostWithoutPort(r.Host)
if trustProxyHeaders {
xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
if xfHost != "" {
xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
if xfHost != "" {
reqHost = hostWithoutPort(xfHost)
}
}
}
reqHost = strings.ToLower(reqHost)
if reqHost == "" {
return false
}
return originHost == reqHost
}
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
if r == nil {
return false
}
if !opsWSProxyConfig.TrustProxy {
return false
}
peerIP, ok := requestPeerIP(r)
if !ok {
return false
}
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
}
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
if r == nil {
return netip.Addr{}, false
}
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err != nil {
host = strings.TrimSpace(r.RemoteAddr)
}
host = strings.TrimPrefix(host, "[")
host = strings.TrimSuffix(host, "]")
if host == "" {
return netip.Addr{}, false
}
addr, err := netip.ParseAddr(host)
if err != nil {
return netip.Addr{}, false
}
return addr.Unmap(), true
}
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
if !addr.IsValid() {
return false
}
for _, p := range trusted {
if p.Contains(addr) {
return true
}
}
return false
}
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
cfg := OpsWSProxyConfig{
TrustProxy: true,
TrustedProxies: defaultTrustedProxies(),
OriginPolicy: OriginPolicyPermissive,
}
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
normalized := strings.ToLower(v)
switch normalized {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
return cfg
}
func defaultTrustedProxies() []netip.Prefix {
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
return prefixes
}
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
for _, token := range strings.Split(raw, ",") {
item := strings.TrimSpace(token)
if item == "" {
continue
}
var (
p netip.Prefix
err error
)
if strings.Contains(item, "/") {
p, err = netip.ParsePrefix(item)
} else {
var addr netip.Addr
addr, err = netip.ParseAddr(item)
if err == nil {
addr = addr.Unmap()
bits := 128
if addr.Is4() {
bits = 32
}
p = netip.PrefixFrom(addr, bits)
}
}
if err != nil || !p.IsValid() {
invalid = append(invalid, item)
continue
}
prefixes = append(prefixes, p.Masked())
}
return prefixes, invalid
}
func hostWithoutPort(hostport string) string {
hostport = strings.TrimSpace(hostport)
if hostport == "" {
return ""
}
if host, _, err := net.SplitHostPort(hostport); err == nil {
return host
}
if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
return strings.Trim(hostport, "[]")
}
parts := strings.Split(hostport, ":")
return parts[0]
}

View File

@@ -0,0 +1,123 @@
package admin
import (
"net/http"
"net/netip"
"testing"
)
func TestIsAllowedOpsWSOrigin_AllowsEmptyOrigin(t *testing.T) {
original := opsWSProxyConfig
t.Cleanup(func() { opsWSProxyConfig = original })
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyPermissive}
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
if !isAllowedOpsWSOrigin(req) {
t.Fatalf("expected empty Origin to be allowed")
}
}
func TestIsAllowedOpsWSOrigin_RejectsEmptyOrigin_WhenStrict(t *testing.T) {
original := opsWSProxyConfig
t.Cleanup(func() { opsWSProxyConfig = original })
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyStrict}
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
if isAllowedOpsWSOrigin(req) {
t.Fatalf("expected empty Origin to be rejected under strict policy")
}
}
func TestIsAllowedOpsWSOrigin_UsesXForwardedHostOnlyFromTrustedProxy(t *testing.T) {
original := opsWSProxyConfig
t.Cleanup(func() { opsWSProxyConfig = original })
opsWSProxyConfig = OpsWSProxyConfig{
TrustProxy: true,
TrustedProxies: []netip.Prefix{
netip.MustParsePrefix("127.0.0.0/8"),
},
}
// Untrusted peer: ignore X-Forwarded-Host and compare against r.Host.
{
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
req.RemoteAddr = "192.0.2.1:12345"
req.Host = "internal.service.local"
req.Header.Set("Origin", "https://public.example.com")
req.Header.Set("X-Forwarded-Host", "public.example.com")
if isAllowedOpsWSOrigin(req) {
t.Fatalf("expected Origin to be rejected when peer is not a trusted proxy")
}
}
// Trusted peer: allow X-Forwarded-Host to participate in Origin validation.
{
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
req.RemoteAddr = "127.0.0.1:23456"
req.Host = "internal.service.local"
req.Header.Set("Origin", "https://public.example.com")
req.Header.Set("X-Forwarded-Host", "public.example.com")
if !isAllowedOpsWSOrigin(req) {
t.Fatalf("expected Origin to be accepted when peer is a trusted proxy")
}
}
}
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicy(t *testing.T) {
t.Setenv(envOpsWSOriginPolicy, "STRICT")
cfg := loadOpsWSProxyConfigFromEnv()
if cfg.OriginPolicy != OriginPolicyStrict {
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyStrict)
}
}
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicyInvalidUsesDefault(t *testing.T) {
t.Setenv(envOpsWSOriginPolicy, "nope")
cfg := loadOpsWSProxyConfigFromEnv()
if cfg.OriginPolicy != OriginPolicyPermissive {
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyPermissive)
}
}
func TestParseTrustedProxyList(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.1, 10.0.0.0/8, bad, ::1/128")
if len(prefixes) != 3 {
t.Fatalf("prefixes=%d, want 3", len(prefixes))
}
if len(invalid) != 1 || invalid[0] != "bad" {
t.Fatalf("invalid=%v, want [bad]", invalid)
}
}
func TestRequestPeerIP_ParsesIPv6(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
req.RemoteAddr = "[::1]:1234"
addr, ok := requestPeerIP(req)
if !ok {
t.Fatalf("expected IPv6 peer IP to parse")
}
if addr != netip.MustParseAddr("::1") {
t.Fatalf("addr=%s, want ::1", addr)
}
}

View File

@@ -36,22 +36,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
@@ -64,13 +64,13 @@ type UpdateSettingsRequest struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
@@ -81,9 +81,9 @@ type UpdateSettingsRequest struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -106,8 +106,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// Turnstile 参数验证
@@ -143,22 +143,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
SmtpPort: req.SmtpPort,
SmtpUsername: req.SmtpUsername,
SmtpPassword: req.SmtpPassword,
SmtpFrom: req.SmtpFrom,
SmtpFromName: req.SmtpFromName,
SmtpUseTLS: req.SmtpUseTLS,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DocURL: req.DocURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
}
@@ -178,67 +178,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
}
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPUseTLS bool `json:"smtp_use_tls"`
}
// TestSmtpConnection 测试SMTP连接
// TestSMTPConnection 测试SMTP连接
// POST /api/v1/admin/settings/test-smtp
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
var req TestSmtpRequest
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
var req TestSMTPRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
Password: password,
UseTLS: req.SmtpUseTLS,
UseTLS: req.SMTPUseTLS,
}
err := h.emailService.TestSmtpConnectionWithConfig(config)
err := h.emailService.TestSMTPConnectionWithConfig(config)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -250,13 +250,13 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
}
// SendTestEmail 发送测试邮件
@@ -268,27 +268,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
password := req.SMTPPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
Password: password,
From: req.SmtpFrom,
FromName: req.SmtpFromName,
UseTLS: req.SmtpUseTLS,
From: req.SMTPFrom,
FromName: req.SMTPFromName,
UseTLS: req.SMTPUseTLS,
}
siteName := h.settingService.GetSiteName(c.Request.Context())
@@ -333,10 +333,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
response.Success(c, gin.H{"message": "Test email sent successfully"})
}
// GetAdminApiKey 获取管理员 API Key 状态
// GetAdminAPIKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -348,10 +348,10 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
})
}
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -362,10 +362,10 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
})
}
// DeleteAdminApiKey 删除管理员 API Key
// DeleteAdminAPIKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@@ -17,14 +17,14 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
response.Success(c, result)
}
// SearchApiKeys handles searching API keys by user
// SearchAPIKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id
}
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return simplified API key list (only id and name)
type SimpleApiKey struct {
type SimpleAPIKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleApiKey, len(keys))
result := make([]SimpleAPIKey, len(keys))
for i, k := range keys {
result[i] = SimpleApiKey{
result[i] = SimpleAPIKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,

View File

@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
return
}
out := make([]dto.ApiKey, 0, len(keys))
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -14,11 +14,11 @@ import (
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
@@ -56,9 +56,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
return
}
out := make([]dto.ApiKey, 0, len(keys))
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -90,7 +90,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Create handles creating a new API key
@@ -108,7 +108,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
svcReq := service.CreateApiKeyRequest{
svcReq := service.CreateAPIKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
@@ -119,7 +119,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Update handles updating an API key
@@ -143,7 +143,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq := service.UpdateApiKeyRequest{}
svcReq := service.UpdateAPIKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
@@ -158,7 +158,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.ApiKeyFromService(key))
response.Success(c, dto.APIKeyFromService(key))
}
// Delete handles deleting an API key

View File

@@ -1,3 +1,4 @@
// Package dto provides mapping utilities for converting between service layer and HTTP handler DTOs.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,11 +27,11 @@ func UserFromService(u *service.User) *User {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
if len(u.APIKeys) > 0 {
out.APIKeys = make([]APIKey, 0, len(u.APIKeys))
for i := range u.APIKeys {
k := u.APIKeys[i]
out.APIKeys = append(out.APIKeys, *APIKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
@@ -43,11 +44,11 @@ func UserFromService(u *service.User) *User {
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
func APIKeyFromService(k *service.APIKey) *APIKey {
if k == nil {
return nil
}
return &ApiKey{
return &APIKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
@@ -220,7 +221,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
APIKeyID: l.APIKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
@@ -245,7 +246,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
APIKey: APIKeyFromService(l.APIKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),

View File

@@ -5,13 +5,13 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
@@ -20,9 +20,9 @@ type SystemSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
@@ -36,8 +36,8 @@ type PublicSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DocURL string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -15,11 +15,11 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
@@ -136,7 +136,7 @@ type RedeemCode struct {
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -168,7 +168,7 @@ type UsageLog struct {
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`

View File

@@ -1,3 +1,5 @@
// Package handler provides HTTP request handlers for the API gateway.
// It handles authentication, request routing, concurrency control, and billing validation.
package handler
import (
@@ -27,6 +29,7 @@ type GatewayHandler struct {
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
opsService *service.OpsService
}
// NewGatewayHandler creates a new GatewayHandler
@@ -37,6 +40,7 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
opsService *service.OpsService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
@@ -45,14 +49,15 @@ func NewGatewayHandler(
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
opsService: opsService,
}
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
// 从context获取apiKey和userAPIKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -87,6 +92,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
setOpsRequestContext(c, reqModel, reqStream)
// 验证 model 必填
if reqModel == "" {
@@ -258,7 +264,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -382,7 +388,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -399,7 +405,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetApiKeyFromContext(c)
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
var groupID *int64
var platform string
@@ -448,7 +454,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetApiKeyFromContext(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -573,6 +579,7 @@ func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string)
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
recordOpsError(c, h.opsService, status, errType, message, "")
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@@ -604,6 +611,7 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
recordOpsError(c, h.opsService, status, errType, message, "")
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
@@ -617,8 +625,8 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
// 从context获取apiKey和userAPIKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return

View File

@@ -20,7 +20,7 @@ import (
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -66,7 +66,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -119,7 +119,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
apiKey, ok := middleware.GetAPIKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -298,7 +298,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,

View File

@@ -7,6 +7,7 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
Ops *admin.OpsHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler

View File

@@ -22,6 +22,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
opsService *service.OpsService
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -29,19 +30,21 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
opsService *service.OpsService,
) *OpenAIGatewayHandler {
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
opsService: opsService,
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
// Get apiKey and user from context (set by APIKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -79,6 +82,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
setOpsRequestContext(c, reqModel, reqStream)
// 验证 model 必填
if reqModel == "" {
@@ -235,7 +239,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -278,6 +282,7 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@@ -297,6 +302,7 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,

View File

@@ -0,0 +1,166 @@
package handler
import (
"context"
"strings"
"sync"
"time"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
)
const (
opsErrorLogWorkerCount = 10
opsErrorLogQueueSize = 256
opsErrorLogTimeout = 2 * time.Second
)
type opsErrorLogJob struct {
ops *service.OpsService
entry *service.OpsErrorLog
}
var (
opsErrorLogOnce sync.Once
opsErrorLogQueue chan opsErrorLogJob
)
func startOpsErrorLogWorkers() {
opsErrorLogQueue = make(chan opsErrorLogJob, opsErrorLogQueueSize)
for i := 0; i < opsErrorLogWorkerCount; i++ {
go func() {
for job := range opsErrorLogQueue {
if job.ops == nil || job.entry == nil {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry)
cancel()
}
}()
}
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsErrorLog) {
if ops == nil || entry == nil {
return
}
opsErrorLogOnce.Do(startOpsErrorLogWorkers)
select {
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
default:
// Queue is full; drop to avoid blocking request handling.
}
}
func setOpsRequestContext(c *gin.Context, model string, stream bool) {
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
}
func recordOpsError(c *gin.Context, ops *service.OpsService, status int, errType, message, fallbackPlatform string) {
if ops == nil || c == nil {
return
}
model, _ := c.Get(opsModelKey)
stream, _ := c.Get(opsStreamKey)
var modelName string
if m, ok := model.(string); ok {
modelName = m
}
streaming, _ := stream.(bool)
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
logEntry := &service.OpsErrorLog{
Phase: classifyOpsPhase(errType, message),
Type: errType,
Severity: classifyOpsSeverity(errType, status),
StatusCode: status,
Platform: resolveOpsPlatform(apiKey, fallbackPlatform),
Model: modelName,
RequestID: c.Writer.Header().Get("x-request-id"),
Message: message,
ClientIP: c.ClientIP(),
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
Stream: streaming,
}
if apiKey != nil {
logEntry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
logEntry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
logEntry.GroupID = apiKey.GroupID
}
}
enqueueOpsErrorLog(ops, logEntry)
}
func resolveOpsPlatform(apiKey *service.APIKey, fallback string) string {
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform != "" {
return apiKey.Group.Platform
}
return fallback
}
func classifyOpsPhase(errType, message string) string {
msg := strings.ToLower(message)
switch errType {
case "authentication_error":
return "auth"
case "billing_error", "subscription_error":
return "billing"
case "rate_limit_error":
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") {
return "concurrency"
}
return "upstream"
case "invalid_request_error":
return "response"
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
return "scheduling"
}
return "internal"
default:
return "internal"
}
}
func classifyOpsSeverity(errType string, status int) string {
switch errType {
case "invalid_request_error", "authentication_error", "billing_error", "subscription_error":
return "P3"
}
if status >= 500 {
return "P1"
}
if status == 429 {
return "P1"
}
if status >= 400 {
return "P2"
}
return "P3"
}

View File

@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DocURL: settings.DocURL,
Version: h.version,
})
}

View File

@@ -18,11 +18,11 @@ import (
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.ApiKeyService
apiKeyService *service.APIKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
Model: model,
Stream: stream,
BillingType: billingType,
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
})
}
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req BatchApiKeysUsageRequest
var req BatchAPIKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
if len(req.APIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
// Limit the number of API key IDs to prevent SQL parameter overflow
if len(req.ApiKeyIDs) > 100 {
if len(req.APIKeyIDs) > 100 {
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
return
}
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
if len(validApiKeyIDs) == 0 {
if len(validAPIKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -10,6 +10,7 @@ import (
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
opsHandler *admin.OpsHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
Ops: opsHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
// Admin handlers
admin.NewDashboardHandler,
admin.NewOpsHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,

View File

@@ -1,3 +1,5 @@
// Package antigravity provides a client for interacting with Google's Antigravity API,
// handling OAuth authentication, token management, and account tier information retrieval.
package antigravity
import (

View File

@@ -1,3 +1,4 @@
// Package claude provides Claude API client constants and utilities.
package claude
// Claude Code 客户端相关常量
@@ -16,13 +17,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
// DefaultHeaders are the default request headers for Claude Code client.
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",

View File

@@ -1,3 +1,4 @@
// Package errors provides custom error types and error handling utilities.
// nolint:mnd
package errors

View File

@@ -1,7 +1,7 @@
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
// This package is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`

View File

@@ -1,3 +1,5 @@
// Package geminicli provides OAuth authentication and API client functionality
// for Google's Gemini AI services, supporting both AI Studio and Code Assist endpoints.
package geminicli
import "time"

View File

@@ -1,3 +1,4 @@
// Package googleapi provides utilities for Google API interactions.
package googleapi
import "net/http"

View File

@@ -1,3 +1,4 @@
// Package oauth provides OAuth 2.0 utilities including PKCE flow, session management, and token exchange.
package oauth
import (

View File

@@ -1,3 +1,4 @@
// Package openai provides OpenAI API models and configuration.
package openai
import _ "embed"

View File

@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
// UserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string

View File

@@ -1,3 +1,4 @@
// Package pagination provides utilities for handling paginated queries and results.
package pagination
// PaginationParams 分页参数

View File

@@ -1,3 +1,4 @@
// Package response provides HTTP response utilities for standardized API responses and error handling.
package response
import (

View File

@@ -1,3 +1,4 @@
// Package sysutil provides system-level utilities for service management.
package sysutil
import (

View File

@@ -1,3 +1,4 @@
// Package usagestats defines types for tracking and reporting API usage statistics.
package usagestats
import "time"
@@ -10,8 +11,8 @@ type DashboardStats struct {
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
@@ -82,10 +83,10 @@ type UserUsageTrendPoint struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
@@ -94,8 +95,8 @@ type ApiKeyUsageTrendPoint struct {
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
@@ -128,7 +129,7 @@ type UserDashboardStats struct {
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
@@ -157,9 +158,9 @@ type BatchUserUsageStats struct {
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}

View File

@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
},
accType: service.AccountTypeApiKey,
accType: service.AccountTypeAPIKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
},
},
{

View File

@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient)
apiKeyRepo := NewAPIKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
key := &service.APIKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",

View File

@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
return &apiKeyCache{rdb: rdb}
}

View File

@@ -13,11 +13,11 @@ import (
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
type APIKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -78,7 +78,7 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
func (s *APIKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -122,6 +122,6 @@ func (s *ApiKeyCacheSuite) TestDailyUsage() {
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
func TestAPIKeyCacheSuite(t *testing.T) {
suite.Run(t, new(APIKeyCacheSuite))
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
func TestAPIKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64

View File

@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client
}
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 不加载完整的 APIKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.ApiKey.Update().
builder := r.client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update().
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
exists, err := r.client.APIKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists {
return nil
}
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids, err := r.client.ApiKey.Query().
ids, err := r.client.APIKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update().
n, err := r.client.APIKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.ApiKey{
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,

View File

@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite"
)
type ApiKeyRepoSuite struct {
type APIKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
func (s *APIKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
func (s *APIKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
func (s *APIKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
func (s *APIKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -150,10 +150,10 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
func (s *APIKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -161,10 +161,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
@@ -174,10 +174,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
func (s *APIKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -186,13 +186,13 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
func (s *APIKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -202,10 +202,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
func (s *APIKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -214,9 +214,9 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
func (s *APIKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -227,47 +227,47 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists)
}
// --- SearchApiKeys ---
// --- SearchAPIKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -284,10 +284,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -320,13 +320,13 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
s.T().Helper()
k := &service.ApiKey{
k := &service.APIKey{
UserID: userID,
Key: key,
Name: name,

View File

@@ -27,8 +27,14 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
@@ -94,27 +100,55 @@ var (
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
if current >= tonumber(ARGV[1]) then
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
@@ -144,6 +178,111 @@ var (
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end
local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -244,7 +383,9 @@ func userSlotKey(userID int64) string {
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
}
func accountWaitKey(accountID int64) string {
@@ -308,8 +449,16 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
if err != nil {
return false, err
}
@@ -317,11 +466,35 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
return err
}
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -335,7 +508,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
userKey := waitQueueKey(userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
require.Equal(s.T(), 1, val, "expected wait count 1")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
userKey := waitQueueKey(userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
@@ -210,12 +210,15 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
require.NoError(s.T(), err)
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {

View File

@@ -1,4 +1,4 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// Package repository 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository

View File

@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a
}
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
t.Helper()
ctx := context.Background()
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
k.Name = "default"
}
create := client.ApiKey.Create().
create := client.APIKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).

View File

@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update().
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {

View File

@@ -0,0 +1,190 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// ListErrorLogs queries ops_error_logs with optional filters and pagination.
// It returns the list items and the total count of matching rows.
func (r *OpsRepository) ListErrorLogs(ctx context.Context, filter *service.ErrorLogFilter) ([]*service.ErrorLog, int64, error) {
page := 1
pageSize := 20
if filter != nil {
if filter.Page > 0 {
page = filter.Page
}
if filter.PageSize > 0 {
pageSize = filter.PageSize
}
}
if pageSize > 100 {
pageSize = 100
}
offset := (page - 1) * pageSize
conditions := make([]string, 0)
args := make([]any, 0)
addCondition := func(condition string, values ...any) {
conditions = append(conditions, condition)
args = append(args, values...)
}
if filter != nil {
// 默认查询最近 24 小时
if filter.StartTime == nil && filter.EndTime == nil {
defaultStart := time.Now().Add(-24 * time.Hour)
filter.StartTime = &defaultStart
}
if filter.StartTime != nil {
addCondition(fmt.Sprintf("created_at >= $%d", len(args)+1), *filter.StartTime)
}
if filter.EndTime != nil {
addCondition(fmt.Sprintf("created_at <= $%d", len(args)+1), *filter.EndTime)
}
if filter.ErrorCode != nil {
addCondition(fmt.Sprintf("status_code = $%d", len(args)+1), *filter.ErrorCode)
}
if provider := strings.TrimSpace(filter.Provider); provider != "" {
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), provider)
}
if filter.AccountID != nil {
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
}
}
where := ""
if len(conditions) > 0 {
where = "WHERE " + strings.Join(conditions, " AND ")
}
countQuery := fmt.Sprintf(`SELECT COUNT(1) FROM ops_error_logs %s`, where)
var total int64
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
if err == sql.ErrNoRows {
total = 0
} else {
return nil, 0, err
}
}
listQuery := fmt.Sprintf(`
SELECT
id,
created_at,
severity,
request_id,
account_id,
request_path,
platform,
model,
status_code,
error_message,
duration_ms,
retry_count,
stream
FROM ops_error_logs
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, where, len(args)+1, len(args)+2)
listArgs := append(append([]any{}, args...), pageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, listArgs...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
results := make([]*service.ErrorLog, 0)
for rows.Next() {
var (
id int64
createdAt time.Time
severity sql.NullString
requestID sql.NullString
accountID sql.NullInt64
requestURI sql.NullString
platform sql.NullString
model sql.NullString
statusCode sql.NullInt64
message sql.NullString
durationMs sql.NullInt64
retryCount sql.NullInt64
stream sql.NullBool
)
if err := rows.Scan(
&id,
&createdAt,
&severity,
&requestID,
&accountID,
&requestURI,
&platform,
&model,
&statusCode,
&message,
&durationMs,
&retryCount,
&stream,
); err != nil {
return nil, 0, err
}
entry := &service.ErrorLog{
ID: id,
Timestamp: createdAt,
Level: levelFromSeverity(severity.String),
RequestID: requestID.String,
APIPath: requestURI.String,
Provider: platform.String,
Model: model.String,
HTTPCode: int(statusCode.Int64),
Stream: stream.Bool,
}
if accountID.Valid {
entry.AccountID = strconv.FormatInt(accountID.Int64, 10)
}
if message.Valid {
entry.ErrorMessage = message.String
}
if durationMs.Valid {
v := int(durationMs.Int64)
entry.DurationMs = &v
}
if retryCount.Valid {
v := int(retryCount.Int64)
entry.RetryCount = &v
}
results = append(results, entry)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return results, total, nil
}
func levelFromSeverity(severity string) string {
sev := strings.ToUpper(strings.TrimSpace(severity))
switch sev {
case "P0", "P1":
return "CRITICAL"
case "P2":
return "ERROR"
case "P3":
return "WARN"
default:
return "ERROR"
}
}

View File

@@ -0,0 +1,127 @@
package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
opsLatestMetricsKey = "ops:metrics:latest"
opsDashboardOverviewKeyPrefix = "ops:dashboard:overview:"
opsLatestMetricsTTL = 10 * time.Second
)
func (r *OpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*service.OpsMetrics, error) {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil, nil
}
data, err := r.rdb.Get(ctx, opsLatestMetricsKey).Bytes()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("redis get cached latest system metric: %w", err)
}
var metric service.OpsMetrics
if err := json.Unmarshal(data, &metric); err != nil {
return nil, fmt.Errorf("unmarshal cached latest system metric: %w", err)
}
return &metric, nil
}
func (r *OpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *service.OpsMetrics) error {
if metric == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil
}
data, err := json.Marshal(metric)
if err != nil {
return fmt.Errorf("marshal cached latest system metric: %w", err)
}
return r.rdb.Set(ctx, opsLatestMetricsKey, data, opsLatestMetricsTTL).Err()
}
func (r *OpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*service.DashboardOverviewData, error) {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil, nil
}
rangeKey := strings.TrimSpace(timeRange)
if rangeKey == "" {
rangeKey = "1h"
}
key := opsDashboardOverviewKeyPrefix + rangeKey
data, err := r.rdb.Get(ctx, key).Bytes()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("redis get cached dashboard overview: %w", err)
}
var overview service.DashboardOverviewData
if err := json.Unmarshal(data, &overview); err != nil {
return nil, fmt.Errorf("unmarshal cached dashboard overview: %w", err)
}
return &overview, nil
}
func (r *OpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *service.DashboardOverviewData, ttl time.Duration) error {
if data == nil {
return nil
}
if ttl <= 0 {
ttl = 10 * time.Second
}
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil
}
rangeKey := strings.TrimSpace(timeRange)
if rangeKey == "" {
rangeKey = "1h"
}
payload, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal cached dashboard overview: %w", err)
}
key := opsDashboardOverviewKeyPrefix + rangeKey
return r.rdb.Set(ctx, key, payload, ttl).Err()
}
func (r *OpsRepository) PingRedis(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return errors.New("redis client is nil")
}
return r.rdb.Ping(ctx).Err()
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
return u
}
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
@@ -53,28 +53,28 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.ApiKey.Query().
got, err := client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
@@ -86,15 +86,15 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.ApiKey.Query().
_, err = client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")

View File

@@ -117,7 +117,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
args := []any{
log.UserID,
log.ApiKeyID,
log.APIKeyID,
log.AccountID,
log.RequestID,
log.Model,
@@ -183,7 +183,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
@@ -270,8 +270,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -418,8 +418,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
@@ -623,7 +623,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
@@ -709,11 +709,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +755,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]ApiKeyUsageTrendPoint, 0)
results = make([]APIKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
@@ -844,7 +844,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
&stats.TotalAPIKeys,
); err != nil {
return nil, err
}
@@ -853,7 +853,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -1023,9 +1023,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.ApiKeyID > 0 {
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID)
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1145,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
@@ -1582,7 +1582,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
@@ -1603,8 +1603,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
@@ -1642,7 +1642,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1676,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
@@ -1800,7 +1800,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,

View File

@@ -35,10 +35,10 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: inputTokens,
@@ -55,12 +55,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
@@ -76,7 +76,7 @@ func (s *UsageLogRepoSuite) TestCreate() {
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -96,7 +96,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -112,7 +112,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -124,18 +124,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
// --- ListByAPIKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAPIKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
@@ -144,7 +144,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -159,7 +159,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -179,7 +179,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -211,8 +211,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
@@ -223,7 +223,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
@@ -240,7 +240,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
@@ -254,7 +254,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
@@ -272,8 +272,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
@@ -300,14 +300,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
@@ -315,7 +315,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -331,8 +331,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
@@ -351,24 +351,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
// --- GetBatchAPIKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -377,7 +377,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -402,7 +402,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -417,11 +417,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
// --- ListByAPIKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -431,8 +431,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
s.Require().Len(logs, 2)
}
@@ -440,7 +440,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -459,7 +459,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -467,7 +467,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
@@ -480,7 +480,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
@@ -493,7 +493,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
@@ -515,7 +515,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
@@ -535,7 +535,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -552,7 +552,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -571,7 +571,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -579,7 +579,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -592,7 +592,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -618,7 +618,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -646,7 +646,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -665,14 +665,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -685,7 +685,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -719,7 +719,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -727,7 +727,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -740,7 +740,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -782,8 +782,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -799,12 +799,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
// --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -815,14 +815,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -832,21 +832,21 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
@@ -855,7 +855,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -874,7 +874,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -885,7 +885,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}

View File

@@ -28,12 +28,13 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewOpsRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
@@ -42,7 +43,7 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewAPIKeyCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,

View File

@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{
deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 2,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
ApiKeyID: 100,
APIKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
@@ -331,7 +331,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubApiKeyRepo
apiKeyRepo *stubAPIKeyRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
}
@@ -359,20 +359,20 @@ func newContractDeps(t *testing.T) *contractDeps {
},
}
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
apiKeyRepo := newStubAPIKeyRepo(now)
apiKeyCache := stubAPIKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
ApiKeyPrefix: "sk-",
APIKeyPrefix: "sk-",
},
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
@@ -525,25 +525,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
type stubApiKeyCache struct{}
type stubAPIKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
@@ -660,24 +660,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i
return 0, errors.New("not implemented")
}
type stubApiKeyRepo struct {
type stubAPIKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
byID map[int64]*service.APIKey
byKey map[string]*service.APIKey
}
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo {
return &stubAPIKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.APIKey),
}
}
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
if key == nil {
return
}
@@ -686,7 +686,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
@@ -706,38 +706,38 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *key
return &clone, nil
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
@@ -748,17 +748,17 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
return nil
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
@@ -776,7 +776,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids)
}
out := make([]service.ApiKey, 0, end-start)
out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
@@ -796,7 +796,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
}, nil
}
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -815,7 +815,7 @@ func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK
return out, nil
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
for _, key := range r.byID {
if key.UserID == userID {
@@ -825,24 +825,24 @@ func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64
return count, nil
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key]
return ok, nil
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
@@ -877,7 +877,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -890,7 +890,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -922,7 +922,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
@@ -975,7 +975,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil
}
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -995,7 +995,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -1017,8 +1017,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters
var filtered []service.UsageLog
for _, log := range logs {
// Apply ApiKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
// Apply APIKeyID filter
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue
}
// Apply Model filter
@@ -1151,8 +1151,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.APIKeyRepository = (*stubAPIKeyRepo)(nil)
_ service.APIKeyCache = (*stubAPIKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)

View File

@@ -1,3 +1,4 @@
// Package server provides HTTP server setup and routing configuration.
package server
import (
@@ -25,8 +26,8 @@ func ProvideRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {

View File

@@ -32,7 +32,7 @@ func adminAuth(
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
@@ -52,19 +52,48 @@ func adminAuth(
}
}
// WebSocket 请求无法设置自定义 header允许在 query 中携带凭证
if isWebSocketRequest(c) {
if token := strings.TrimSpace(c.Query("token")); token != "" {
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
return
}
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
func isWebSocketRequest(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
return true
}
conn := strings.ToLower(c.GetHeader("Connection"))
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
}
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false

View File

@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware {
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
@@ -53,6 +53,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 如果所有header都没有API key
if apiKeyString == "" {
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return
}
@@ -60,35 +61,40 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -109,12 +115,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
apiKey.Group.ID,
)
if err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
@@ -131,6 +139,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
@@ -140,13 +149,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -157,13 +167,66 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) {
if opsService == nil || c == nil {
return
}
errType := "authentication_error"
phase := "auth"
severity := "P3"
switch status {
case 403:
errType = "billing_error"
phase = "billing"
case 429:
errType = "rate_limit_error"
phase = "billing"
severity = "P2"
case 500:
errType = "api_error"
phase = "internal"
severity = "P1"
}
logEntry := &service.OpsErrorLog{
Phase: phase,
Type: errType,
Severity: severity,
StatusCode: status,
Message: message,
ClientIP: c.ClientIP(),
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
}
if apiKey != nil {
logEntry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
logEntry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
logEntry.GroupID = apiKey.GroupID
}
if apiKey.Group != nil {
logEntry.Platform = apiKey.Group.Platform
}
}
enqueueOpsAuthErrorLog(opsService, logEntry)
}
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.ApiKey)
apiKey, ok := value.(*service.APIKey)
return apiKey, ok
}

View File

@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin"
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
}
}
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,

View File

@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewAPIKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
@@ -85,16 +85,16 @@ func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService
)
}
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -109,16 +109,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrAPIKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -134,16 +134,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -159,13 +159,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
require.Equal(t, "INTERNAL", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -192,13 +192,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
ID: 1,
Key: key,
Status: service.StatusActive,
@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)

View File

@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
apiKeyRepo := &stubAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return router
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
type stubAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}

View File

@@ -2,11 +2,14 @@ package middleware
import (
"log"
"regexp"
"time"
"github.com/gin-gonic/gin"
)
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
@@ -26,7 +29,7 @@ func Logger() gin.HandlerFunc {
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***")
// 状态码
statusCode := c.Writer.Status()

View File

@@ -1,3 +1,5 @@
// Package middleware provides HTTP middleware components for authentication,
// authorization, logging, error recovery, and request processing.
package middleware
import (
@@ -15,8 +17,8 @@ const (
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeyAPIKey API密钥上下文键
ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)

View File

@@ -0,0 +1,55 @@
package middleware
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
opsAuthErrorLogWorkerCount = 10
opsAuthErrorLogQueueSize = 256
opsAuthErrorLogTimeout = 2 * time.Second
)
type opsAuthErrorLogJob struct {
ops *service.OpsService
entry *service.OpsErrorLog
}
var (
opsAuthErrorLogOnce sync.Once
opsAuthErrorLogQueue chan opsAuthErrorLogJob
)
func startOpsAuthErrorLogWorkers() {
opsAuthErrorLogQueue = make(chan opsAuthErrorLogJob, opsAuthErrorLogQueueSize)
for i := 0; i < opsAuthErrorLogWorkerCount; i++ {
go func() {
for job := range opsAuthErrorLogQueue {
if job.ops == nil || job.entry == nil {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), opsAuthErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry)
cancel()
}
}()
}
}
func enqueueOpsAuthErrorLog(ops *service.OpsService, entry *service.OpsErrorLog) {
if ops == nil || entry == nil {
return
}
opsAuthErrorLogOnce.Do(startOpsAuthErrorLogWorkers)
select {
case opsAuthErrorLogQueue <- opsAuthErrorLogJob{ops: ops, entry: entry}:
default:
// Queue is full; drop to avoid blocking request handling.
}
}

View File

@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// APIKeyAuthMiddleware API Key 认证中间件类型
type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
NewAPIKeyAuthMiddleware,
)

View File

@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {

View File

@@ -19,6 +19,9 @@ func RegisterAdminRoutes(
// 仪表盘
registerDashboardRoutes(admin, h)
// 运维监控
registerOpsRoutes(admin, h)
// 用户管理
registerUserManagementRoutes(admin, h)
@@ -67,10 +70,35 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
}
}
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops := admin.Group("/ops")
{
ops.GET("/metrics", h.Admin.Ops.GetMetrics)
ops.GET("/metrics/history", h.Admin.Ops.ListMetricsHistory)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/error-logs", h.Admin.Ops.ListErrorLogs)
// Dashboard routes
dashboard := ops.Group("/dashboard")
{
dashboard.GET("/overview", h.Admin.Ops.GetDashboardOverview)
dashboard.GET("/providers", h.Admin.Ops.GetProviderHealth)
dashboard.GET("/latency-histogram", h.Admin.Ops.GetLatencyHistogram)
dashboard.GET("/errors/distribution", h.Admin.Ops.GetErrorDistribution)
}
// WebSocket routes
ws := ops.Group("/ws")
{
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
}
}
@@ -203,12 +231,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
}
}
@@ -248,7 +276,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
}
}

View File

@@ -1,3 +1,4 @@
// Package routes 提供 HTTP 路由注册和处理函数
package routes
import (

View File

@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -62,7 +62,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)

View File

@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
}
// 卡密兑换

View File

@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
if a.Type != AccountTypeAPIKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
func (a *Account) IsOpenAIAPIKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
func (a *Account) GetOpenAIAPIKey() string {
if !a.IsOpenAIAPIKey() {
return ""
}
return a.GetCredential("api_key")

View File

@@ -1,3 +1,5 @@
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service
import (

View File

@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIApiKey()
authToken = account.GetOpenAIAPIKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -17,11 +17,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +32,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +51,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)

View File

@@ -19,7 +19,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
@@ -30,7 +30,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +65,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
// CreateUserInput represents the input for creating a new user
type CreateUserInput struct {
Email string
Password string
@@ -220,7 +220,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
@@ -232,7 +232,7 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {

View File

@@ -2,7 +2,7 @@ package service
import "time"
type ApiKey struct {
type APIKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group
}
func (k *ApiKey) IsActive() bool {
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cache APIKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cache APIKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
) *APIKeyService {
return &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewApiKeyService(
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
return ErrAPIKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
return ErrAPIKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
return ErrAPIKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
// Key已存在,增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &ApiKey{
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
func TestAPIKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -32,6 +32,7 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {

View File

@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
} `json:"data"`
}
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,20 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SmtpConfig SMTP配置
type SmtpConfig struct {
// SMTPConfig SMTP配置
type SMTPConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySmtpHost]
host := settings[SettingKeySMTPHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
config, err := s.GetSMTPConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

View File

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1639,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}

View File

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")

View File

@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

View File

@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
// FetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient()

View File

@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
case AccountTypeAPIKey:
apiKey := account.GetOpenAIAPIKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,

View File

@@ -0,0 +1,99 @@
package service
import (
"context"
"time"
)
// ErrorLog represents an ops error log item for list queries.
//
// Field naming matches docs/API-运维监控中心2.0.md (L3 根因追踪 - 错误日志列表).
type ErrorLog struct {
ID int64 `json:"id"`
Timestamp time.Time `json:"timestamp"`
Level string `json:"level,omitempty"`
RequestID string `json:"request_id,omitempty"`
AccountID string `json:"account_id,omitempty"`
APIPath string `json:"api_path,omitempty"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
HTTPCode int `json:"http_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
DurationMs *int `json:"duration_ms,omitempty"`
RetryCount *int `json:"retry_count,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ErrorLogFilter describes optional filters and pagination for listing ops error logs.
type ErrorLogFilter struct {
StartTime *time.Time
EndTime *time.Time
ErrorCode *int
Provider string
AccountID *int64
Page int
PageSize int
}
func (f *ErrorLogFilter) normalize() (page, pageSize int) {
page = 1
pageSize = 20
if f == nil {
return page, pageSize
}
if f.Page > 0 {
page = f.Page
}
if f.PageSize > 0 {
pageSize = f.PageSize
}
if pageSize > 100 {
pageSize = 100
}
return page, pageSize
}
type ErrorLogListResponse struct {
Errors []*ErrorLog `json:"errors"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *ErrorLogFilter) (*ErrorLogListResponse, error) {
if s == nil || s.repo == nil {
return &ErrorLogListResponse{
Errors: []*ErrorLog{},
Total: 0,
Page: 1,
PageSize: 20,
}, nil
}
page, pageSize := filter.normalize()
if filter == nil {
filter = &ErrorLogFilter{}
}
filter.Page = page
filter.PageSize = pageSize
items, total, err := s.repo.ListErrorLogs(ctx, filter)
if err != nil {
return nil, err
}
if items == nil {
items = []*ErrorLog{}
}
return &ErrorLogListResponse{
Errors: items,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}

View File

@@ -0,0 +1,834 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
type OpsAlertService struct {
opsService *OpsService
userService *UserService
emailService *EmailService
httpClient *http.Client
interval time.Duration
startOnce sync.Once
stopOnce sync.Once
stopCtx context.Context
stop context.CancelFunc
wg sync.WaitGroup
}
// opsAlertEvalInterval defines how often OpsAlertService evaluates alert rules.
//
// Production uses opsMetricsInterval. Tests may override this variable to keep
// integration tests fast without changing production defaults.
var opsAlertEvalInterval = opsMetricsInterval
func NewOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
return &OpsAlertService{
opsService: opsService,
userService: userService,
emailService: emailService,
httpClient: &http.Client{Timeout: 10 * time.Second},
interval: opsAlertEvalInterval,
}
}
// Start launches the background alert evaluation loop.
//
// Stop must be called during shutdown to ensure the goroutine exits.
func (s *OpsAlertService) Start() {
s.StartWithContext(context.Background())
}
// StartWithContext is like Start but allows the caller to provide a parent context.
// When the parent context is canceled, the service stops automatically.
func (s *OpsAlertService) StartWithContext(ctx context.Context) {
if s == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
s.startOnce.Do(func() {
if s.interval <= 0 {
s.interval = opsAlertEvalInterval
}
s.stopCtx, s.stop = context.WithCancel(ctx)
s.wg.Add(1)
go s.run()
})
}
// Stop gracefully stops the background goroutine started by Start/StartWithContext.
// It is safe to call Stop multiple times.
func (s *OpsAlertService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.stop != nil {
s.stop()
}
})
s.wg.Wait()
}
func (s *OpsAlertService) run() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.evaluateOnce()
for {
select {
case <-ticker.C:
s.evaluateOnce()
case <-s.stopCtx.Done():
return
}
}
}
func (s *OpsAlertService) evaluateOnce() {
ctx, cancel := context.WithTimeout(s.stopCtx, opsAlertEvaluateTimeout)
defer cancel()
s.Evaluate(ctx, time.Now())
}
func (s *OpsAlertService) Evaluate(ctx context.Context, now time.Time) {
if s == nil || s.opsService == nil {
return
}
rules, err := s.opsService.ListAlertRules(ctx)
if err != nil {
log.Printf("[OpsAlert] failed to list rules: %v", err)
return
}
if len(rules) == 0 {
return
}
maxSustainedByWindow := make(map[int]int)
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
if sustained > maxSustainedByWindow[window] {
maxSustainedByWindow[window] = sustained
}
}
metricsByWindow := make(map[int][]OpsMetrics)
for window, limit := range maxSustainedByWindow {
metrics, err := s.opsService.ListRecentSystemMetrics(ctx, window, limit)
if err != nil {
log.Printf("[OpsAlert] failed to load metrics window=%dm: %v", window, err)
continue
}
metricsByWindow[window] = metrics
}
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
metrics := metricsByWindow[window]
selected, ok := selectContiguousMetrics(metrics, sustained, now)
if !ok {
continue
}
breached, latestValue, ok := evaluateRule(rule, selected)
if !ok {
continue
}
activeEvent, err := s.opsService.GetActiveAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get active event (rule=%d): %v", rule.ID, err)
continue
}
if breached {
if activeEvent != nil {
continue
}
lastEvent, err := s.opsService.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get latest event (rule=%d): %v", rule.ID, err)
continue
}
if lastEvent != nil && rule.CooldownMinutes > 0 {
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
if now.Sub(lastEvent.FiredAt) < cooldown {
continue
}
}
event := &OpsAlertEvent{
RuleID: rule.ID,
Severity: rule.Severity,
Status: OpsAlertStatusFiring,
Title: fmt.Sprintf("%s: %s", rule.Severity, rule.Name),
Description: buildAlertDescription(rule, latestValue),
MetricValue: latestValue,
ThresholdValue: rule.Threshold,
FiredAt: now,
CreatedAt: now,
}
if err := s.opsService.CreateAlertEvent(ctx, event); err != nil {
log.Printf("[OpsAlert] failed to create event (rule=%d): %v", rule.ID, err)
continue
}
emailSent, webhookSent := s.dispatchNotifications(ctx, rule, event)
if emailSent || webhookSent {
if err := s.opsService.UpdateAlertEventNotifications(ctx, event.ID, emailSent, webhookSent); err != nil {
log.Printf("[OpsAlert] failed to update notification flags (event=%d): %v", event.ID, err)
}
}
} else if activeEvent != nil {
resolvedAt := now
if err := s.opsService.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
log.Printf("[OpsAlert] failed to resolve event (event=%d): %v", activeEvent.ID, err)
}
}
}
}
const opsMetricsContinuityTolerance = 20 * time.Second
// selectContiguousMetrics picks the newest N metrics and verifies they are continuous.
//
// This prevents a sustained rule from triggering when metrics sampling has gaps
// (e.g. collector downtime) and avoids evaluating "stale" data.
//
// Assumptions:
// - Metrics are ordered by UpdatedAt DESC (newest first).
// - Metrics are expected to be collected at opsMetricsInterval cadence.
func selectContiguousMetrics(metrics []OpsMetrics, needed int, now time.Time) ([]OpsMetrics, bool) {
if needed <= 0 {
return nil, false
}
if len(metrics) < needed {
return nil, false
}
newest := metrics[0].UpdatedAt
if newest.IsZero() {
return nil, false
}
if now.Sub(newest) > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
selected := metrics[:needed]
for i := 0; i < len(selected)-1; i++ {
a := selected[i].UpdatedAt
b := selected[i+1].UpdatedAt
if a.IsZero() || b.IsZero() {
return nil, false
}
gap := a.Sub(b)
if gap < opsMetricsInterval-opsMetricsContinuityTolerance || gap > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
}
return selected, true
}
func evaluateRule(rule OpsAlertRule, metrics []OpsMetrics) (bool, float64, bool) {
if len(metrics) == 0 {
return false, 0, false
}
latestValue, ok := metricValue(metrics[0], rule.MetricType)
if !ok {
return false, 0, false
}
for _, metric := range metrics {
value, ok := metricValue(metric, rule.MetricType)
if !ok || !compareMetric(value, rule.Operator, rule.Threshold) {
return false, latestValue, true
}
}
return true, latestValue, true
}
func metricValue(metric OpsMetrics, metricType string) (float64, bool) {
switch metricType {
case OpsMetricSuccessRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.SuccessRate, true
case OpsMetricErrorRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.ErrorRate, true
case OpsMetricP95LatencyMs:
return float64(metric.P95LatencyMs), true
case OpsMetricP99LatencyMs:
return float64(metric.P99LatencyMs), true
case OpsMetricHTTP2Errors:
return float64(metric.HTTP2Errors), true
case OpsMetricCPUUsagePercent:
return metric.CPUUsagePercent, true
case OpsMetricMemoryUsagePercent:
return metric.MemoryUsagePercent, true
case OpsMetricQueueDepth:
return float64(metric.ConcurrencyQueueDepth), true
default:
return 0, false
}
}
func compareMetric(value float64, operator string, threshold float64) bool {
switch operator {
case ">":
return value > threshold
case ">=":
return value >= threshold
case "<":
return value < threshold
case "<=":
return value <= threshold
case "==":
return value == threshold
default:
return false
}
}
func buildAlertDescription(rule OpsAlertRule, value float64) string {
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
return fmt.Sprintf("Rule %s triggered: %s %s %.2f (current %.2f) over last %dm",
rule.Name,
rule.MetricType,
rule.Operator,
rule.Threshold,
value,
window,
)
}
func (s *OpsAlertService) dispatchNotifications(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) (bool, bool) {
emailSent := false
webhookSent := false
notifyCtx, cancel := s.notificationContext(ctx)
defer cancel()
if rule.NotifyEmail {
emailSent = s.sendEmailNotification(notifyCtx, rule, event)
}
if rule.NotifyWebhook && rule.WebhookURL != "" {
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
// Fallback channel: if email is enabled but ultimately fails, try webhook even if the
// webhook toggle is off (as long as a webhook URL is configured).
if rule.NotifyEmail && !emailSent && !rule.NotifyWebhook && rule.WebhookURL != "" {
log.Printf("[OpsAlert] email failed; attempting webhook fallback (rule=%d)", rule.ID)
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
return emailSent, webhookSent
}
const (
opsAlertEvaluateTimeout = 45 * time.Second
opsAlertNotificationTimeout = 30 * time.Second
opsAlertEmailMaxRetries = 3
)
var opsAlertEmailBackoff = []time.Duration{
1 * time.Second,
2 * time.Second,
4 * time.Second,
}
func (s *OpsAlertService) notificationContext(ctx context.Context) (context.Context, context.CancelFunc) {
parent := ctx
if s != nil && s.stopCtx != nil {
parent = s.stopCtx
}
if parent == nil {
parent = context.Background()
}
return context.WithTimeout(parent, opsAlertNotificationTimeout)
}
var opsAlertSleep = sleepWithContext
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
if ctx == nil {
time.Sleep(d)
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func retryWithBackoff(
ctx context.Context,
maxRetries int,
backoff []time.Duration,
fn func() error,
onError func(attempt int, total int, nextDelay time.Duration, err error),
) error {
if ctx == nil {
ctx = context.Background()
}
if maxRetries < 0 {
maxRetries = 0
}
totalAttempts := maxRetries + 1
var lastErr error
for attempt := 1; attempt <= totalAttempts; attempt++ {
if attempt > 1 {
backoffIdx := attempt - 2
if backoffIdx < len(backoff) {
if err := opsAlertSleep(ctx, backoff[backoffIdx]); err != nil {
return err
}
}
}
if err := ctx.Err(); err != nil {
return err
}
if err := fn(); err != nil {
lastErr = err
nextDelay := time.Duration(0)
if attempt < totalAttempts {
nextIdx := attempt - 1
if nextIdx < len(backoff) {
nextDelay = backoff[nextIdx]
}
}
if onError != nil {
onError(attempt, totalAttempts, nextDelay, err)
}
continue
}
return nil
}
return lastErr
}
func (s *OpsAlertService) sendEmailNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
if s.emailService == nil || s.userService == nil {
return false
}
if ctx == nil {
ctx = context.Background()
}
admin, err := s.userService.GetFirstAdmin(ctx)
if err != nil || admin == nil || admin.Email == "" {
return false
}
subject := fmt.Sprintf("[Ops Alert][%s] %s", rule.Severity, rule.Name)
body := fmt.Sprintf(
"Alert triggered: %s\n\nMetric: %s\nThreshold: %.2f\nCurrent: %.2f\nWindow: %dm\nStatus: %s\nTime: %s",
rule.Name,
rule.MetricType,
rule.Threshold,
event.MetricValue,
rule.WindowMinutes,
event.Status,
event.FiredAt.Format(time.RFC3339),
)
config, err := s.emailService.GetSMTPConfig(ctx)
if err != nil {
log.Printf("[OpsAlert] email config load failed: %v", err)
return false
}
if err := retryWithBackoff(
ctx,
opsAlertEmailMaxRetries,
opsAlertEmailBackoff,
func() error {
return s.emailService.SendEmailWithConfig(config, admin.Email, subject, body)
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt < total {
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), retrying in %s: %v", attempt, total, nextDelay, err)
return
}
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), giving up: %v", attempt, total, err)
},
); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("[OpsAlert] email send canceled: %v", err)
}
return false
}
return true
}
func (s *OpsAlertService) sendWebhookNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
webhookTarget, err := validateWebhookURL(ctx, rule.WebhookURL)
if err != nil {
log.Printf("[OpsAlert] invalid webhook url (rule=%d): %v", rule.ID, err)
return false
}
payload := map[string]any{
"rule_id": rule.ID,
"rule_name": rule.Name,
"severity": rule.Severity,
"status": event.Status,
"metric_type": rule.MetricType,
"metric_value": event.MetricValue,
"threshold_value": rule.Threshold,
"window_minutes": rule.WindowMinutes,
"fired_at": event.FiredAt.Format(time.RFC3339),
}
body, err := json.Marshal(payload)
if err != nil {
return false
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookTarget.URL.String(), bytes.NewReader(body))
if err != nil {
return false
}
req.Header.Set("Content-Type", "application/json")
resp, err := buildWebhookHTTPClient(s.httpClient, webhookTarget).Do(req)
if err != nil {
log.Printf("[OpsAlert] webhook send failed: %v", err)
return false
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
log.Printf("[OpsAlert] webhook returned status %d", resp.StatusCode)
return false
}
return true
}
const webhookHTTPClientTimeout = 10 * time.Second
func buildWebhookHTTPClient(base *http.Client, webhookTarget *validatedWebhookTarget) *http.Client {
var client http.Client
if base != nil {
client = *base
}
if client.Timeout <= 0 {
client.Timeout = webhookHTTPClientTimeout
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
if webhookTarget != nil {
client.Transport = buildWebhookTransport(client.Transport, webhookTarget)
}
return &client
}
var disallowedWebhookIPNets = []net.IPNet{
// "this host on this network" / unspecified.
mustParseCIDR("0.0.0.0/8"),
mustParseCIDR("127.0.0.0/8"), // loopback (includes 127.0.0.1)
mustParseCIDR("10.0.0.0/8"), // RFC1918
mustParseCIDR("192.168.0.0/16"), // RFC1918
mustParseCIDR("172.16.0.0/12"), // RFC1918 (172.16.0.0 - 172.31.255.255)
mustParseCIDR("100.64.0.0/10"), // RFC6598 (carrier-grade NAT)
mustParseCIDR("169.254.0.0/16"), // IPv4 link-local (includes 169.254.169.254 metadata IP on many clouds)
mustParseCIDR("198.18.0.0/15"), // RFC2544 benchmark testing
mustParseCIDR("224.0.0.0/4"), // IPv4 multicast
mustParseCIDR("240.0.0.0/4"), // IPv4 reserved
mustParseCIDR("::/128"), // IPv6 unspecified
mustParseCIDR("::1/128"), // IPv6 loopback
mustParseCIDR("fc00::/7"), // IPv6 unique local
mustParseCIDR("fe80::/10"), // IPv6 link-local
mustParseCIDR("ff00::/8"), // IPv6 multicast
}
func mustParseCIDR(cidr string) net.IPNet {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
return *block
}
var lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
}
type validatedWebhookTarget struct {
URL *url.URL
host string
port string
pinnedIPs []net.IP
}
var webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
}
func buildWebhookTransport(base http.RoundTripper, webhookTarget *validatedWebhookTarget) http.RoundTripper {
if webhookTarget == nil || webhookTarget.URL == nil {
return base
}
var transport *http.Transport
switch typed := base.(type) {
case *http.Transport:
if typed != nil {
transport = typed.Clone()
}
}
if transport == nil {
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok && defaultTransport != nil {
transport = defaultTransport.Clone()
} else {
transport = (&http.Transport{}).Clone()
}
}
webhookHost := webhookTarget.host
webhookPort := webhookTarget.port
pinnedIPs := append([]net.IP(nil), webhookTarget.pinnedIPs...)
transport.Proxy = nil
transport.DialTLSContext = nil
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil || host == "" || port == "" {
return nil, fmt.Errorf("webhook dial target is invalid: %q", addr)
}
canonicalHost := strings.TrimSuffix(strings.ToLower(host), ".")
if canonicalHost != webhookHost || port != webhookPort {
return nil, fmt.Errorf("webhook dial target mismatch: %q", addr)
}
var lastErr error
for _, ip := range pinnedIPs {
if isDisallowedWebhookIP(ip) {
lastErr = fmt.Errorf("webhook target resolves to a disallowed ip")
continue
}
dialAddr := net.JoinHostPort(ip.String(), port)
conn, err := webhookBaseDialContext(ctx, network, dialAddr)
if err == nil {
return conn, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("webhook target has no resolved addresses")
}
return nil, lastErr
}
return transport
}
func validateWebhookURL(ctx context.Context, raw string) (*validatedWebhookTarget, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, errors.New("webhook url is empty")
}
// Avoid request smuggling / header injection vectors.
if strings.ContainsAny(raw, "\r\n") {
return nil, errors.New("webhook url contains invalid characters")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, errors.New("webhook url format is invalid")
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, errors.New("webhook url scheme must be https")
}
parsed.Scheme = "https"
if parsed.Host == "" || parsed.Hostname() == "" {
return nil, errors.New("webhook url must include host")
}
if parsed.User != nil {
return nil, errors.New("webhook url must not include userinfo")
}
if parsed.Port() != "" {
port, err := strconv.Atoi(parsed.Port())
if err != nil || port < 1 || port > 65535 {
return nil, errors.New("webhook url port is invalid")
}
}
host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
if host == "localhost" {
return nil, errors.New("webhook url host must not be localhost")
}
if ip := net.ParseIP(host); ip != nil {
if isDisallowedWebhookIP(ip) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: []net.IP{ip},
}, nil
}
if ctx == nil {
ctx = context.Background()
}
ips, err := lookupIPAddrs(ctx, host)
if err != nil || len(ips) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
pinned := make([]net.IP, 0, len(ips))
for _, addr := range ips {
if isDisallowedWebhookIP(addr.IP) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
if addr.IP != nil {
pinned = append(pinned, addr.IP)
}
}
if len(pinned) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: uniqueResolvedIPs(pinned),
}, nil
}
func isDisallowedWebhookIP(ip net.IP) bool {
if ip == nil {
return false
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
} else if ip16 := ip.To16(); ip16 != nil {
ip = ip16
} else {
return false
}
// Disallow non-public addresses even if they're not explicitly covered by the CIDR list.
// This provides defense-in-depth against SSRF targets such as link-local, multicast, and
// unspecified addresses, and ensures any "pinned" IP is still blocked at dial time.
if ip.IsUnspecified() ||
ip.IsLoopback() ||
ip.IsMulticast() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsPrivate() {
return true
}
for _, block := range disallowedWebhookIPNets {
if block.Contains(ip) {
return true
}
}
return false
}
func portForScheme(u *url.URL) string {
if u != nil && u.Port() != "" {
return u.Port()
}
return "443"
}
func uniqueResolvedIPs(ips []net.IP) []net.IP {
seen := make(map[string]struct{}, len(ips))
out := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if ip == nil {
continue
}
key := ip.String()
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, ip)
}
return out
}

View File

@@ -0,0 +1,271 @@
//go:build integration
package service
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// This integration test protects the DI startup contract for OpsAlertService.
//
// Background:
// - OpsMetricsCollector previously called alertService.Start()/Evaluate() directly.
// - Those direct calls were removed, so OpsAlertService must now start via DI
// (ProvideOpsAlertService in wire.go) and run its own evaluation ticker.
//
// What we validate here:
// 1. When we construct via the Wire provider functions (ProvideOpsAlertService +
// ProvideOpsMetricsCollector), OpsAlertService starts automatically.
// 2. Its evaluation loop continues to tick even if OpsMetricsCollector is stopped,
// proving the alert evaluator is independent.
// 3. The evaluation path can trigger alert logic (CreateAlertEvent called).
func TestOpsAlertService_StartedViaWireProviders_RunsIndependentTicker(t *testing.T) {
oldInterval := opsAlertEvalInterval
opsAlertEvalInterval = 25 * time.Millisecond
t.Cleanup(func() { opsAlertEvalInterval = oldInterval })
repo := newFakeOpsRepository()
opsService := NewOpsService(repo, nil)
// Start via the Wire provider function (the production DI path).
alertService := ProvideOpsAlertService(opsService, nil, nil)
t.Cleanup(alertService.Stop)
// Construct via ProvideOpsMetricsCollector (wire.go). Stop immediately to ensure
// the alert ticker keeps running without the metrics collector.
collector := ProvideOpsMetricsCollector(opsService, NewConcurrencyService(nil))
collector.Stop()
// Wait for at least one evaluation (run() calls evaluateOnce immediately).
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= 1
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation loop keeps ticking after the metrics collector is stopped.
callsAfterCollectorStop := repo.listRulesCalls.Load()
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= callsAfterCollectorStop+2
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation logic actually fires an alert event at least once.
select {
case <-repo.eventCreatedCh:
// ok
case <-time.After(2 * time.Second):
t.Fatalf("expected OpsAlertService to create an alert event, but none was created (ListAlertRules calls=%d)", repo.listRulesCalls.Load())
}
}
func newFakeOpsRepository() *fakeOpsRepository {
return &fakeOpsRepository{
eventCreatedCh: make(chan struct{}),
}
}
// fakeOpsRepository is a lightweight in-memory stub of OpsRepository for integration tests.
// It avoids real DB/Redis usage and provides deterministic responses fast.
type fakeOpsRepository struct {
listRulesCalls atomic.Int64
mu sync.Mutex
activeEvent *OpsAlertEvent
latestEvent *OpsAlertEvent
nextEventID int64
eventCreatedCh chan struct{}
eventOnce sync.Once
}
func (r *fakeOpsRepository) CreateErrorLog(ctx context.Context, log *OpsErrorLog) error {
return nil
}
func (r *fakeOpsRepository) ListErrorLogsLegacy(ctx context.Context, filters OpsErrorLogFilters) ([]OpsErrorLog, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListErrorLogs(ctx context.Context, filter *ErrorLogFilter) ([]*ErrorLog, int64, error) {
return nil, 0, nil
}
func (r *fakeOpsRepository) GetLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return &OpsMetrics{WindowMinutes: 1}, sql.ErrNoRows
}
func (r *fakeOpsRepository) CreateSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
return &OpsWindowStats{}, nil
}
func (r *fakeOpsRepository) GetProviderStats(ctx context.Context, startTime, endTime time.Time) ([]*ProviderStats, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetLatencyHistogram(ctx context.Context, startTime, endTime time.Time) ([]*LatencyHistogramItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetErrorDistribution(ctx context.Context, startTime, endTime time.Time) ([]*ErrorDistributionItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
if limit <= 0 {
limit = 1
}
now := time.Now()
metrics := make([]OpsMetrics, 0, limit)
for i := 0; i < limit; i++ {
metrics = append(metrics, OpsMetrics{
WindowMinutes: windowMinutes,
CPUUsagePercent: 99,
UpdatedAt: now.Add(-time.Duration(i) * opsMetricsInterval),
})
}
return metrics, nil
}
func (r *fakeOpsRepository) ListSystemMetricsRange(ctx context.Context, windowMinutes int, startTime, endTime time.Time, limit int) ([]OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
call := r.listRulesCalls.Add(1)
// Delay enabling rules slightly so the test can stop OpsMetricsCollector first,
// then observe the alert evaluator ticking independently.
if call < 5 {
return nil, nil
}
return []OpsAlertRule{
{
ID: 1,
Name: "cpu too high (test)",
Enabled: true,
MetricType: OpsMetricCPUUsagePercent,
Operator: ">",
Threshold: 0,
WindowMinutes: 1,
SustainedMinutes: 1,
Severity: "P1",
NotifyEmail: false,
NotifyWebhook: false,
CooldownMinutes: 0,
},
}, nil
}
func (r *fakeOpsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return nil, nil
}
if r.activeEvent.RuleID != ruleID {
return nil, nil
}
if r.activeEvent.Status != OpsAlertStatusFiring {
return nil, nil
}
clone := *r.activeEvent
return &clone, nil
}
func (r *fakeOpsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.latestEvent == nil || r.latestEvent.RuleID != ruleID {
return nil, nil
}
clone := *r.latestEvent
return &clone, nil
}
func (r *fakeOpsRepository) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
if event == nil {
return nil
}
r.mu.Lock()
defer r.mu.Unlock()
r.nextEventID++
event.ID = r.nextEventID
clone := *event
r.latestEvent = &clone
if clone.Status == OpsAlertStatusFiring {
r.activeEvent = &clone
}
r.eventOnce.Do(func() { close(r.eventCreatedCh) })
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.Status = status
r.activeEvent.ResolvedAt = resolvedAt
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.Status = status
r.latestEvent.ResolvedAt = resolvedAt
}
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.EmailSent = emailSent
r.activeEvent.WebhookSent = webhookSent
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.EmailSent = emailSent
r.latestEvent.WebhookSent = webhookSent
}
return nil
}
func (r *fakeOpsRepository) CountActiveAlerts(ctx context.Context) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return 0, nil
}
return 1, nil
}
func (r *fakeOpsRepository) GetOverviewStats(ctx context.Context, startTime, endTime time.Time) (*OverviewStats, error) {
return &OverviewStats{}, nil
}
func (r *fakeOpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*DashboardOverviewData, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *DashboardOverviewData, ttl time.Duration) error {
return nil
}
func (r *fakeOpsRepository) PingRedis(ctx context.Context) error {
return nil
}

View File

@@ -0,0 +1,315 @@
//go:build unit || opsalert_unit
package service
import (
"context"
"errors"
"net"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSelectContiguousMetrics_Contiguous(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
{UpdatedAt: now.Add(-1 * time.Minute)},
{UpdatedAt: now.Add(-2 * time.Minute)},
}
selected, ok := selectContiguousMetrics(metrics, 3, now)
require.True(t, ok)
require.Len(t, selected, 3)
}
func TestSelectContiguousMetrics_GapFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
// Missing the -1m sample (gap ~=2m).
{UpdatedAt: now.Add(-2 * time.Minute)},
{UpdatedAt: now.Add(-3 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 3, now)
require.False(t, ok)
}
func TestSelectContiguousMetrics_StaleNewestFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 10, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now.Add(-10 * time.Minute)},
{UpdatedAt: now.Add(-11 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 2, now)
require.False(t, ok)
}
func TestMetricValue_SuccessRate_NoTrafficIsNoData(t *testing.T) {
metric := OpsMetrics{
RequestCount: 0,
SuccessRate: 0,
}
value, ok := metricValue(metric, OpsMetricSuccessRate)
require.False(t, ok)
require.Equal(t, 0.0, value)
}
func TestOpsAlertService_StopWithoutStart_NoPanic(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
require.NotPanics(t, func() { s.Stop() })
}
func TestOpsAlertService_StartStop_Graceful(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
s.interval = 5 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.StartWithContext(ctx)
done := make(chan struct{})
go func() {
s.Stop()
close(done)
}()
select {
case <-done:
// ok
case <-time.After(1 * time.Second):
t.Fatal("Stop did not return; background goroutine likely stuck")
}
require.NotPanics(t, func() { s.Stop() })
}
func TestBuildWebhookHTTPClient_DefaultTimeout(t *testing.T) {
client := buildWebhookHTTPClient(nil, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
require.ErrorIs(t, client.CheckRedirect(nil, nil), http.ErrUseLastResponse)
base := &http.Client{}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
base = &http.Client{Timeout: 2 * time.Second}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, 2*time.Second, client.Timeout)
require.NotNil(t, client.CheckRedirect)
}
func TestValidateWebhookURL_RequiresHTTPS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "http://example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_InvalidFormatRejected(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://[::1")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsUserinfo(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://user:pass@example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLocalhost(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://localhost/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsPrivateIPLiteral(t *testing.T) {
cases := []string{
"https://0.0.0.0/webhook",
"https://127.0.0.1/webhook",
"https://10.0.0.1/webhook",
"https://192.168.1.2/webhook",
"https://172.16.0.1/webhook",
"https://172.31.255.255/webhook",
"https://100.64.0.1/webhook",
"https://169.254.169.254/webhook",
"https://198.18.0.1/webhook",
"https://224.0.0.1/webhook",
"https://240.0.0.1/webhook",
"https://[::]/webhook",
"https://[::1]/webhook",
"https://[ff02::1]/webhook",
}
for _, tc := range cases {
t.Run(tc, func(t *testing.T) {
_, err := validateWebhookURL(context.Background(), tc)
require.Error(t, err)
})
}
}
func TestValidateWebhookURL_RejectsPrivateIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "internal.example", host)
return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://internal.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLinkLocalIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "metadata.example", host)
return []net.IPAddr{{IP: net.ParseIP("169.254.169.254")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://metadata.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_AllowsPublicHostViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com:443/webhook")
require.NoError(t, err)
require.Equal(t, "https", target.URL.Scheme)
require.Equal(t, "example.com", target.URL.Hostname())
require.Equal(t, "443", target.URL.Port())
}
func TestValidateWebhookURL_RejectsInvalidPort(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://example.com:99999/webhook")
require.Error(t, err)
}
func TestWebhookTransport_UsesPinnedIP_NoDNSRebinding(t *testing.T) {
oldLookup := lookupIPAddrs
oldDial := webhookBaseDialContext
t.Cleanup(func() {
lookupIPAddrs = oldLookup
webhookBaseDialContext = oldDial
})
lookupCalls := 0
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com/webhook")
require.NoError(t, err)
require.Equal(t, 1, lookupCalls)
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
return []net.IPAddr{{IP: net.ParseIP("10.0.0.1")}}, nil
}
var dialAddrs []string
webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialAddrs = append(dialAddrs, addr)
return nil, errors.New("dial blocked in test")
}
client := buildWebhookHTTPClient(nil, target)
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
_, err = transport.DialContext(context.Background(), "tcp", "example.com:443")
require.Error(t, err)
require.Equal(t, []string{"93.184.216.34:443"}, dialAddrs)
require.Equal(t, 1, lookupCalls, "dial path must not re-resolve DNS")
}
func TestRetryWithBackoff_SucceedsAfterRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return nil
}
attempts := 0
err := retryWithBackoff(
context.Background(),
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
if attempts <= 3 {
return errors.New("send failed")
}
return nil
},
nil,
)
require.NoError(t, err)
require.Equal(t, 4, attempts)
require.Equal(t, []time.Duration{time.Second, 2 * time.Second, 4 * time.Second}, slept)
}
func TestRetryWithBackoff_ContextCanceledStopsRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return ctx.Err()
}
ctx, cancel := context.WithCancel(context.Background())
attempts := 0
err := retryWithBackoff(
ctx,
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
return errors.New("send failed")
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt == 1 {
cancel()
}
},
)
require.ErrorIs(t, err, context.Canceled)
require.Equal(t, 1, attempts)
require.Equal(t, []time.Duration{time.Second}, slept)
}

View File

@@ -0,0 +1,92 @@
package service
import (
"context"
"time"
)
const (
OpsAlertStatusFiring = "firing"
OpsAlertStatusResolved = "resolved"
)
const (
OpsMetricSuccessRate = "success_rate"
OpsMetricErrorRate = "error_rate"
OpsMetricP95LatencyMs = "p95_latency_ms"
OpsMetricP99LatencyMs = "p99_latency_ms"
OpsMetricHTTP2Errors = "http2_errors"
OpsMetricCPUUsagePercent = "cpu_usage_percent"
OpsMetricMemoryUsagePercent = "memory_usage_percent"
OpsMetricQueueDepth = "concurrency_queue_depth"
)
type OpsAlertRule struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Enabled bool `json:"enabled"`
MetricType string `json:"metric_type"`
Operator string `json:"operator"`
Threshold float64 `json:"threshold"`
WindowMinutes int `json:"window_minutes"`
SustainedMinutes int `json:"sustained_minutes"`
Severity string `json:"severity"`
NotifyEmail bool `json:"notify_email"`
NotifyWebhook bool `json:"notify_webhook"`
WebhookURL string `json:"webhook_url"`
CooldownMinutes int `json:"cooldown_minutes"`
DimensionFilters map[string]any `json:"dimension_filters,omitempty"`
NotifyChannels []string `json:"notify_channels,omitempty"`
NotifyConfig map[string]any `json:"notify_config,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type OpsAlertEvent struct {
ID int64 `json:"id"`
RuleID int64 `json:"rule_id"`
Severity string `json:"severity"`
Status string `json:"status"`
Title string `json:"title"`
Description string `json:"description"`
MetricValue float64 `json:"metric_value"`
ThresholdValue float64 `json:"threshold_value"`
FiredAt time.Time `json:"fired_at"`
ResolvedAt *time.Time `json:"resolved_at"`
EmailSent bool `json:"email_sent"`
WebhookSent bool `json:"webhook_sent"`
CreatedAt time.Time `json:"created_at"`
}
func (s *OpsService) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
return s.repo.ListAlertRules(ctx)
}
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
return s.repo.GetActiveAlertEvent(ctx, ruleID)
}
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
return s.repo.GetLatestAlertEvent(ctx, ruleID)
}
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
return s.repo.CreateAlertEvent(ctx, event)
}
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
return s.repo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
}
func (s *OpsService) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
return s.repo.UpdateAlertEventNotifications(ctx, eventID, emailSent, webhookSent)
}
func (s *OpsService) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
return s.repo.ListRecentSystemMetrics(ctx, windowMinutes, limit)
}
func (s *OpsService) CountActiveAlerts(ctx context.Context) (int, error) {
return s.repo.CountActiveAlerts(ctx)
}

View File

@@ -0,0 +1,203 @@
package service
import (
"context"
"log"
"runtime"
"sync"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
)
const (
opsMetricsInterval = 1 * time.Minute
opsMetricsCollectTimeout = 10 * time.Second
opsMetricsWindowShortMinutes = 1
opsMetricsWindowLongMinutes = 5
bytesPerMB = 1024 * 1024
cpuUsageSampleInterval = 0 * time.Second
percentScale = 100
)
type OpsMetricsCollector struct {
opsService *OpsService
concurrencyService *ConcurrencyService
interval time.Duration
lastGCPauseTotal uint64
lastGCPauseMu sync.Mutex
stopCh chan struct{}
startOnce sync.Once
stopOnce sync.Once
}
func NewOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
return &OpsMetricsCollector{
opsService: opsService,
concurrencyService: concurrencyService,
interval: opsMetricsInterval,
}
}
func (c *OpsMetricsCollector) Start() {
if c == nil {
return
}
c.startOnce.Do(func() {
if c.stopCh == nil {
c.stopCh = make(chan struct{})
}
go c.run()
})
}
func (c *OpsMetricsCollector) Stop() {
if c == nil {
return
}
c.stopOnce.Do(func() {
if c.stopCh != nil {
close(c.stopCh)
}
})
}
func (c *OpsMetricsCollector) run() {
ticker := time.NewTicker(c.interval)
defer ticker.Stop()
c.collectOnce()
for {
select {
case <-ticker.C:
c.collectOnce()
case <-c.stopCh:
return
}
}
}
func (c *OpsMetricsCollector) collectOnce() {
if c.opsService == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectTimeout)
defer cancel()
now := time.Now()
systemStats := c.collectSystemStats(ctx)
queueDepth := c.collectQueueDepth(ctx)
activeAlerts := c.collectActiveAlerts(ctx)
for _, window := range []int{opsMetricsWindowShortMinutes, opsMetricsWindowLongMinutes} {
startTime := now.Add(-time.Duration(window) * time.Minute)
windowStats, err := c.opsService.GetWindowStats(ctx, startTime, now)
if err != nil {
log.Printf("[OpsMetrics] failed to get window stats (%dm): %v", window, err)
continue
}
successRate, errorRate := computeRates(windowStats.SuccessCount, windowStats.ErrorCount)
requestCount := windowStats.SuccessCount + windowStats.ErrorCount
metric := &OpsMetrics{
WindowMinutes: window,
RequestCount: requestCount,
SuccessCount: windowStats.SuccessCount,
ErrorCount: windowStats.ErrorCount,
SuccessRate: successRate,
ErrorRate: errorRate,
P95LatencyMs: windowStats.P95LatencyMs,
P99LatencyMs: windowStats.P99LatencyMs,
HTTP2Errors: windowStats.HTTP2Errors,
ActiveAlerts: activeAlerts,
CPUUsagePercent: systemStats.cpuUsage,
MemoryUsedMB: systemStats.memoryUsedMB,
MemoryTotalMB: systemStats.memoryTotalMB,
MemoryUsagePercent: systemStats.memoryUsagePercent,
HeapAllocMB: systemStats.heapAllocMB,
GCPauseMs: systemStats.gcPauseMs,
ConcurrencyQueueDepth: queueDepth,
UpdatedAt: now,
}
if err := c.opsService.RecordMetrics(ctx, metric); err != nil {
log.Printf("[OpsMetrics] failed to record metrics (%dm): %v", window, err)
}
}
}
func computeRates(successCount, errorCount int64) (float64, float64) {
total := successCount + errorCount
if total == 0 {
// No traffic => no data. Rates are kept at 0 and request_count will be 0.
// The UI should render this as N/A instead of "100% success".
return 0, 0
}
successRate := float64(successCount) / float64(total) * percentScale
errorRate := float64(errorCount) / float64(total) * percentScale
return successRate, errorRate
}
type opsSystemStats struct {
cpuUsage float64
memoryUsedMB int64
memoryTotalMB int64
memoryUsagePercent float64
heapAllocMB int64
gcPauseMs float64
}
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) opsSystemStats {
stats := opsSystemStats{}
if percents, err := cpu.PercentWithContext(ctx, cpuUsageSampleInterval, false); err == nil && len(percents) > 0 {
stats.cpuUsage = percents[0]
}
if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil {
stats.memoryUsedMB = int64(vm.Used / bytesPerMB)
stats.memoryTotalMB = int64(vm.Total / bytesPerMB)
stats.memoryUsagePercent = vm.UsedPercent
}
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
stats.heapAllocMB = int64(memStats.HeapAlloc / bytesPerMB)
c.lastGCPauseMu.Lock()
if c.lastGCPauseTotal != 0 && memStats.PauseTotalNs >= c.lastGCPauseTotal {
stats.gcPauseMs = float64(memStats.PauseTotalNs-c.lastGCPauseTotal) / float64(time.Millisecond)
}
c.lastGCPauseTotal = memStats.PauseTotalNs
c.lastGCPauseMu.Unlock()
return stats
}
func (c *OpsMetricsCollector) collectQueueDepth(ctx context.Context) int {
if c.concurrencyService == nil {
return 0
}
depth, err := c.concurrencyService.GetTotalWaitCount(ctx)
if err != nil {
log.Printf("[OpsMetrics] failed to get queue depth: %v", err)
return 0
}
return depth
}
func (c *OpsMetricsCollector) collectActiveAlerts(ctx context.Context) int {
if c.opsService == nil {
return 0
}
count, err := c.opsService.CountActiveAlerts(ctx)
if err != nil {
return 0
}
return count
}

File diff suppressed because it is too large Load Diff

View File

@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyAPIBaseURL,
SettingKeyContactInfo,
SettingKeyDocUrl,
SettingKeyDocURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}, nil
}
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
}
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,9 +115,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeyDocURL] = settings.DocURL
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -198,8 +198,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +210,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
} else {
result.SmtpPort = 587
result.SMTPPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -245,8 +245,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result
@@ -278,28 +278,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -320,10 +320,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// GetAdminAPIKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -333,7 +333,7 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
}

View File

@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
@@ -19,9 +19,9 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
DefaultConcurrency int
DefaultBalance float64
@@ -35,8 +35,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
Version string
}

Some files were not shown because too many files have changed in this diff Show More