运维监控系统安全加固和功能优化 (#21)
* fix(ops): 修复运维监控系统的关键安全和稳定性问题
## 修复内容
### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
- 实现IP钉住机制防止验证后的DNS rebinding攻击
- 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
- 扩展IP黑名单,包括云metadata地址(169.254.169.254)
- 添加完整的单元测试覆盖
2. **OpsAlertService生命周期管理** (wire.go)
- 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
- 确保stopCtx正确初始化,避免nil指针问题
- 实现防御式启动,保证服务启动顺序
3. **数据库查询排序** (ops_repo.go)
- 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
- 在GetLatestSystemMetric中添加排序保证
- 避免数据库返回顺序不确定导致告警误判
### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
- 为lastGCPauseTotal字段添加sync.Mutex保护
- 防止数据竞争
5. **Goroutine泄漏** (ops_error_logger.go)
- 实现worker pool模式限制并发goroutine数量
- 使用256容量缓冲队列和10个固定worker
- 非阻塞投递,队列满时丢弃任务
6. **生命周期控制** (ops_alert_service.go)
- 添加Start/Stop方法实现优雅关闭
- 使用context控制goroutine生命周期
- 实现WaitGroup等待后台任务完成
7. **Webhook URL验证** (ops_alert_service.go)
- 防止SSRF攻击:验证scheme、禁止内网IP
- DNS解析验证,拒绝解析到私有IP的域名
- 添加8个单元测试覆盖各种攻击场景
8. **资源泄漏** (ops_repo.go)
- 修复多处defer rows.Close()问题
- 简化冗余的defer func()包装
9. **HTTP超时控制** (ops_alert_service.go)
- 创建带10秒超时的http.Client
- 添加buildWebhookHTTPClient辅助函数
- 防止HTTP请求无限期挂起
10. **数据库查询优化** (ops_repo.go)
- 将GetWindowStats的4次独立查询合并为1次CTE查询
- 减少网络往返和表扫描次数
- 显著提升性能
11. **重试机制** (ops_alert_service.go)
- 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
- 添加webhook备用通道
- 实现完整的错误处理和日志记录
12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
- 提取硬编码数字为有意义的常量
- 提高代码可读性和可维护性
## 测试验证
- ✅ go test ./internal/service -tags opsalert_unit 通过
- ✅ 所有webhook验证测试通过
- ✅ 重试机制测试通过
## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容
* feat(ops): 运维监控系统V2 - 完整实现
## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)
## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)
### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置
### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数
## 一致性修复(98/100分)
### P0级别(阻塞Migration)
- ✅ 修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
- ✅ 修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
- ✅ 统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
- ✅ 统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)
### P1级别(功能完整性)
- ✅ 修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
- ✅ 修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
- ✅ 添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)
### P2级别(优化)
- ✅ 前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
- ✅ 后端WebSocket心跳检测(30s ping,60s pong超时)
## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)
### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)
## 测试验证
- ✅ 所有Go测试通过
- ✅ Migration可正常执行
- ✅ WebSocket连接稳定
- ✅ 前后端数据结构对齐
* refactor: 代码清理和测试优化
## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式
## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理
## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式
变更统计: 27个文件,292行新增,322行删除(净减少30行)
* fix(ops): 运维监控系统安全加固和功能优化
## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token
## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
- OPS_WS_TRUST_PROXY
- OPS_WS_TRUSTED_PROXIES
- OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用
## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离
## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes
## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go
Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts
Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)
Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)
* fix(migrations): 修复calculate_health_score函数类型匹配问题
在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配
* fix(lint): 修复golangci-lint检查发现的所有问题
- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段
修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)
代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
// Package config provides application configuration management.
|
||||
package config
|
||||
|
||||
import (
|
||||
@@ -139,7 +140,7 @@ type GatewayConfig struct {
|
||||
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
|
||||
|
||||
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
|
||||
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
@@ -241,7 +242,7 @@ type DefaultConfig struct {
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
UserConcurrency int `mapstructure:"user_concurrency"`
|
||||
UserBalance float64 `mapstructure:"user_balance"`
|
||||
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
APIKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
RateMultiplier float64 `mapstructure:"rate_multiplier"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package config provides application configuration management.
|
||||
package config
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package admin provides HTTP handlers for administrative operations including
|
||||
// dashboard statistics, user management, API key management, and account management.
|
||||
package admin
|
||||
|
||||
import (
|
||||
@@ -75,8 +77,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
"total_api_keys": stats.TotalAPIKeys,
|
||||
"active_api_keys": stats.ActiveAPIKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
@@ -193,10 +195,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GetAPIKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
@@ -205,7 +207,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
@@ -273,26 +275,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
var req BatchAPIKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -18,6 +18,7 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GetCapabilities retrieves OAuth configuration capabilities.
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
|
||||
@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
outKeys := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
402
backend/internal/handler/admin/ops_handler.go
Normal file
402
backend/internal/handler/admin/ops_handler.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpsHandler handles ops dashboard endpoints.
|
||||
type OpsHandler struct {
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// NewOpsHandler creates a new OpsHandler.
|
||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
||||
return &OpsHandler{opsService: opsService}
|
||||
}
|
||||
|
||||
// GetMetrics returns the latest ops metrics snapshot.
|
||||
// GET /api/v1/admin/ops/metrics
|
||||
func (h *OpsHandler) GetMetrics(c *gin.Context) {
|
||||
metrics, err := h.opsService.GetLatestMetrics(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get ops metrics")
|
||||
return
|
||||
}
|
||||
response.Success(c, metrics)
|
||||
}
|
||||
|
||||
// ListMetricsHistory returns a time-range slice of metrics for charts.
|
||||
// GET /api/v1/admin/ops/metrics/history
|
||||
//
|
||||
// Query params:
|
||||
// - window_minutes: int (default 1)
|
||||
// - minutes: int (lookback; optional)
|
||||
// - start_time/end_time: RFC3339 timestamps (optional; overrides minutes when provided)
|
||||
// - limit: int (optional; max 100, default 300 for backward compatibility)
|
||||
func (h *OpsHandler) ListMetricsHistory(c *gin.Context) {
|
||||
windowMinutes := 1
|
||||
if v := c.Query("window_minutes"); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
windowMinutes = parsed
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid window_minutes")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limit := 300
|
||||
limitProvided := false
|
||||
if v := c.Query("limit"); v != "" {
|
||||
parsed, err := strconv.Atoi(v)
|
||||
if err != nil || parsed <= 0 || parsed > 5000 {
|
||||
response.BadRequest(c, "Invalid limit (must be 1-5000)")
|
||||
return
|
||||
}
|
||||
limit = parsed
|
||||
limitProvided = true
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
startTime := time.Time{}
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
startTime = parsed
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
endTime = parsed
|
||||
}
|
||||
|
||||
// If explicit range not provided, use lookback minutes.
|
||||
if startTime.IsZero() {
|
||||
if v := c.Query("minutes"); v != "" {
|
||||
minutes, err := strconv.Atoi(v)
|
||||
if err != nil || minutes <= 0 {
|
||||
response.BadRequest(c, "Invalid minutes")
|
||||
return
|
||||
}
|
||||
if minutes > 60*24*7 {
|
||||
minutes = 60 * 24 * 7
|
||||
}
|
||||
startTime = endTime.Add(-time.Duration(minutes) * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Default time range: last 24 hours.
|
||||
if startTime.IsZero() {
|
||||
startTime = endTime.Add(-24 * time.Hour)
|
||||
if !limitProvided {
|
||||
// Metrics are collected at 1-minute cadence; 24h requires ~1440 points.
|
||||
limit = 24 * 60
|
||||
}
|
||||
}
|
||||
|
||||
if startTime.After(endTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.opsService.ListMetricsHistory(c.Request.Context(), windowMinutes, startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to list ops metrics history")
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
// ListErrorLogs lists recent error logs with optional filters.
|
||||
// GET /api/v1/admin/ops/error-logs
|
||||
//
|
||||
// Query params:
|
||||
// - start_time/end_time: RFC3339 timestamps (optional)
|
||||
// - platform: string (optional)
|
||||
// - phase: string (optional)
|
||||
// - severity: string (optional)
|
||||
// - q: string (optional; fuzzy match)
|
||||
// - limit: int (optional; default 100; max 500)
|
||||
func (h *OpsHandler) ListErrorLogs(c *gin.Context) {
|
||||
var filters service.OpsErrorLogFilters
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
startTime, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filters.StartTime = &startTime
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
endTime, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filters.EndTime = &endTime
|
||||
}
|
||||
|
||||
if filters.StartTime != nil && filters.EndTime != nil && filters.StartTime.After(*filters.EndTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
filters.Platform = c.Query("platform")
|
||||
filters.Phase = c.Query("phase")
|
||||
filters.Severity = c.Query("severity")
|
||||
filters.Query = c.Query("q")
|
||||
|
||||
filters.Limit = 100
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 || limit > 500 {
|
||||
response.BadRequest(c, "Invalid limit (must be 1-500)")
|
||||
return
|
||||
}
|
||||
filters.Limit = limit
|
||||
}
|
||||
|
||||
items, total, err := h.opsService.ListErrorLogs(c.Request.Context(), filters)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to list error logs")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
// GetDashboardOverview returns realtime ops dashboard overview.
|
||||
// GET /api/v1/admin/ops/dashboard/overview
|
||||
//
|
||||
// Query params:
|
||||
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
|
||||
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
switch timeRange {
|
||||
case "5m", "30m", "1h", "6h", "24h":
|
||||
default:
|
||||
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.opsService.GetDashboardOverview(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get dashboard overview")
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
// GetProviderHealth returns upstream provider health comparison data.
|
||||
// GET /api/v1/admin/ops/dashboard/providers
|
||||
//
|
||||
// Query params:
|
||||
// - time_range: string (optional; default "1h") one of: 5m, 30m, 1h, 6h, 24h
|
||||
func (h *OpsHandler) GetProviderHealth(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
switch timeRange {
|
||||
case "5m", "30m", "1h", "6h", "24h":
|
||||
default:
|
||||
response.BadRequest(c, "Invalid time_range (supported: 5m, 30m, 1h, 6h, 24h)")
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.opsService.GetProviderHealth(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get provider health")
|
||||
return
|
||||
}
|
||||
|
||||
var totalRequests int64
|
||||
var weightedSuccess float64
|
||||
var bestProvider string
|
||||
var worstProvider string
|
||||
var bestRate float64
|
||||
var worstRate float64
|
||||
hasRate := false
|
||||
|
||||
for _, p := range providers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
totalRequests += p.RequestCount
|
||||
weightedSuccess += (p.SuccessRate / 100) * float64(p.RequestCount)
|
||||
|
||||
if p.RequestCount <= 0 {
|
||||
continue
|
||||
}
|
||||
if !hasRate {
|
||||
bestProvider = p.Name
|
||||
worstProvider = p.Name
|
||||
bestRate = p.SuccessRate
|
||||
worstRate = p.SuccessRate
|
||||
hasRate = true
|
||||
continue
|
||||
}
|
||||
|
||||
if p.SuccessRate > bestRate {
|
||||
bestProvider = p.Name
|
||||
bestRate = p.SuccessRate
|
||||
}
|
||||
if p.SuccessRate < worstRate {
|
||||
worstProvider = p.Name
|
||||
worstRate = p.SuccessRate
|
||||
}
|
||||
}
|
||||
|
||||
avgSuccessRate := 0.0
|
||||
if totalRequests > 0 {
|
||||
avgSuccessRate = (weightedSuccess / float64(totalRequests)) * 100
|
||||
avgSuccessRate = math.Round(avgSuccessRate*100) / 100
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"providers": providers,
|
||||
"summary": gin.H{
|
||||
"total_requests": totalRequests,
|
||||
"avg_success_rate": avgSuccessRate,
|
||||
"best_provider": bestProvider,
|
||||
"worst_provider": worstProvider,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetErrorLogs returns a paginated error log list with multi-dimensional filters.
|
||||
// GET /api/v1/admin/ops/errors
|
||||
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
filter := &service.ErrorLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
if startTimeStr := c.Query("start_time"); startTimeStr != "" {
|
||||
startTime, err := time.Parse(time.RFC3339, startTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filter.StartTime = &startTime
|
||||
}
|
||||
if endTimeStr := c.Query("end_time"); endTimeStr != "" {
|
||||
endTime, err := time.Parse(time.RFC3339, endTimeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time format (RFC3339)")
|
||||
return
|
||||
}
|
||||
filter.EndTime = &endTime
|
||||
}
|
||||
|
||||
if filter.StartTime != nil && filter.EndTime != nil && filter.StartTime.After(*filter.EndTime) {
|
||||
response.BadRequest(c, "Invalid time range: start_time must be <= end_time")
|
||||
return
|
||||
}
|
||||
|
||||
if errorCodeStr := c.Query("error_code"); errorCodeStr != "" {
|
||||
code, err := strconv.Atoi(errorCodeStr)
|
||||
if err != nil || code < 0 {
|
||||
response.BadRequest(c, "Invalid error_code")
|
||||
return
|
||||
}
|
||||
filter.ErrorCode = &code
|
||||
}
|
||||
|
||||
// Keep both parameter names for compatibility: provider (docs) and platform (legacy).
|
||||
filter.Provider = c.Query("provider")
|
||||
if filter.Provider == "" {
|
||||
filter.Provider = c.Query("platform")
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
accountID, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil || accountID <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &accountID
|
||||
}
|
||||
|
||||
out, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get error logs")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"errors": out.Errors,
|
||||
"total": out.Total,
|
||||
"page": out.Page,
|
||||
"page_size": out.PageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLatencyHistogram returns the latency distribution histogram.
|
||||
// GET /api/v1/admin/ops/dashboard/latency-histogram
|
||||
func (h *OpsHandler) GetLatencyHistogram(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
buckets, err := h.opsService.GetLatencyHistogram(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get latency histogram")
|
||||
return
|
||||
}
|
||||
|
||||
totalRequests := int64(0)
|
||||
for _, b := range buckets {
|
||||
totalRequests += b.Count
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"buckets": buckets,
|
||||
"total_requests": totalRequests,
|
||||
"slow_request_threshold": 1000,
|
||||
})
|
||||
}
|
||||
|
||||
// GetErrorDistribution returns the error distribution.
|
||||
// GET /api/v1/admin/ops/dashboard/errors/distribution
|
||||
func (h *OpsHandler) GetErrorDistribution(c *gin.Context) {
|
||||
timeRange := c.Query("time_range")
|
||||
if timeRange == "" {
|
||||
timeRange = "1h"
|
||||
}
|
||||
|
||||
items, err := h.opsService.GetErrorDistribution(c.Request.Context(), timeRange)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get error distribution")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"items": items,
|
||||
})
|
||||
}
|
||||
286
backend/internal/handler/admin/ops_ws_handler.go
Normal file
286
backend/internal/handler/admin/ops_ws_handler.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type OpsWSProxyConfig struct {
|
||||
TrustProxy bool
|
||||
TrustedProxies []netip.Prefix
|
||||
OriginPolicy string
|
||||
}
|
||||
|
||||
const (
|
||||
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
|
||||
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
|
||||
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
|
||||
)
|
||||
|
||||
const (
|
||||
OriginPolicyStrict = "strict"
|
||||
OriginPolicyPermissive = "permissive"
|
||||
)
|
||||
|
||||
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return isAllowedOpsWSOrigin(r)
|
||||
},
|
||||
}
|
||||
|
||||
// QPSWSHandler handles realtime QPS push via WebSocket.
|
||||
// GET /api/v1/admin/ops/ws/qps
|
||||
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// Set pong handler
|
||||
if err := conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
|
||||
log.Printf("[OpsWS] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
return conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
})
|
||||
|
||||
// Push QPS data every 2 seconds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Heartbeat ping every 30 seconds
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Fetch 1m window stats for current QPS
|
||||
data, err := h.opsService.GetDashboardOverview(ctx, "5m")
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] get overview failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"type": "qps_update",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"data": gin.H{
|
||||
"qps": data.QPS.Current,
|
||||
"tps": data.TPS.Current,
|
||||
"request_count": data.Errors.TotalCount + int64(data.QPS.Avg1h*60), // Rough estimate
|
||||
},
|
||||
}
|
||||
|
||||
msg, _ := json.Marshal(payload)
|
||||
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||
log.Printf("[OpsWS] write failed: %v", err)
|
||||
return
|
||||
}
|
||||
case <-pingTicker.C:
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("[OpsWS] ping failed: %v", err)
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isAllowedOpsWSOrigin(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
|
||||
case OriginPolicyStrict:
|
||||
return false
|
||||
case OriginPolicyPermissive, "":
|
||||
return true
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil || parsed.Hostname() == "" {
|
||||
return false
|
||||
}
|
||||
originHost := strings.ToLower(parsed.Hostname())
|
||||
|
||||
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
|
||||
reqHost := hostWithoutPort(r.Host)
|
||||
if trustProxyHeaders {
|
||||
xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
|
||||
if xfHost != "" {
|
||||
xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
|
||||
if xfHost != "" {
|
||||
reqHost = hostWithoutPort(xfHost)
|
||||
}
|
||||
}
|
||||
}
|
||||
reqHost = strings.ToLower(reqHost)
|
||||
if reqHost == "" {
|
||||
return false
|
||||
}
|
||||
return originHost == reqHost
|
||||
}
|
||||
|
||||
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if !opsWSProxyConfig.TrustProxy {
|
||||
return false
|
||||
}
|
||||
peerIP, ok := requestPeerIP(r)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
|
||||
}
|
||||
|
||||
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
|
||||
if r == nil {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
if err != nil {
|
||||
host = strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
host = strings.TrimPrefix(host, "[")
|
||||
host = strings.TrimSuffix(host, "]")
|
||||
if host == "" {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
addr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
return addr.Unmap(), true
|
||||
}
|
||||
|
||||
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
for _, p := range trusted {
|
||||
if p.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
cfg := OpsWSProxyConfig{
|
||||
TrustProxy: true,
|
||||
TrustedProxies: defaultTrustedProxies(),
|
||||
OriginPolicy: OriginPolicyPermissive,
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.TrustProxy = parsed
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||
prefixes, invalid := parseTrustedProxyList(raw)
|
||||
if len(invalid) > 0 {
|
||||
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
}
|
||||
cfg.TrustedProxies = prefixes
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
|
||||
normalized := strings.ToLower(v)
|
||||
switch normalized {
|
||||
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||
cfg.OriginPolicy = normalized
|
||||
default:
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func defaultTrustedProxies() []netip.Prefix {
|
||||
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
|
||||
return prefixes
|
||||
}
|
||||
|
||||
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
item := strings.TrimSpace(token)
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var (
|
||||
p netip.Prefix
|
||||
err error
|
||||
)
|
||||
if strings.Contains(item, "/") {
|
||||
p, err = netip.ParsePrefix(item)
|
||||
} else {
|
||||
var addr netip.Addr
|
||||
addr, err = netip.ParseAddr(item)
|
||||
if err == nil {
|
||||
addr = addr.Unmap()
|
||||
bits := 128
|
||||
if addr.Is4() {
|
||||
bits = 32
|
||||
}
|
||||
p = netip.PrefixFrom(addr, bits)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil || !p.IsValid() {
|
||||
invalid = append(invalid, item)
|
||||
continue
|
||||
}
|
||||
|
||||
prefixes = append(prefixes, p.Masked())
|
||||
}
|
||||
return prefixes, invalid
|
||||
}
|
||||
|
||||
func hostWithoutPort(hostport string) string {
|
||||
hostport = strings.TrimSpace(hostport)
|
||||
if hostport == "" {
|
||||
return ""
|
||||
}
|
||||
if host, _, err := net.SplitHostPort(hostport); err == nil {
|
||||
return host
|
||||
}
|
||||
if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
|
||||
return strings.Trim(hostport, "[]")
|
||||
}
|
||||
parts := strings.Split(hostport, ":")
|
||||
return parts[0]
|
||||
}
|
||||
123
backend/internal/handler/admin/ops_ws_handler_test.go
Normal file
123
backend/internal/handler/admin/ops_ws_handler_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_AllowsEmptyOrigin(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyPermissive}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
|
||||
if !isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected empty Origin to be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_RejectsEmptyOrigin_WhenStrict(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
opsWSProxyConfig = OpsWSProxyConfig{OriginPolicy: OriginPolicyStrict}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
|
||||
if isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected empty Origin to be rejected under strict policy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedOpsWSOrigin_UsesXForwardedHostOnlyFromTrustedProxy(t *testing.T) {
|
||||
original := opsWSProxyConfig
|
||||
t.Cleanup(func() { opsWSProxyConfig = original })
|
||||
|
||||
opsWSProxyConfig = OpsWSProxyConfig{
|
||||
TrustProxy: true,
|
||||
TrustedProxies: []netip.Prefix{
|
||||
netip.MustParsePrefix("127.0.0.0/8"),
|
||||
},
|
||||
}
|
||||
|
||||
// Untrusted peer: ignore X-Forwarded-Host and compare against r.Host.
|
||||
{
|
||||
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "192.0.2.1:12345"
|
||||
req.Host = "internal.service.local"
|
||||
req.Header.Set("Origin", "https://public.example.com")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
|
||||
if isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected Origin to be rejected when peer is not a trusted proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// Trusted peer: allow X-Forwarded-Host to participate in Origin validation.
|
||||
{
|
||||
req, err := http.NewRequest(http.MethodGet, "http://internal.service.local", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "127.0.0.1:23456"
|
||||
req.Host = "internal.service.local"
|
||||
req.Header.Set("Origin", "https://public.example.com")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
|
||||
if !isAllowedOpsWSOrigin(req) {
|
||||
t.Fatalf("expected Origin to be accepted when peer is a trusted proxy")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicy(t *testing.T) {
|
||||
t.Setenv(envOpsWSOriginPolicy, "STRICT")
|
||||
cfg := loadOpsWSProxyConfigFromEnv()
|
||||
if cfg.OriginPolicy != OriginPolicyStrict {
|
||||
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyStrict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpsWSProxyConfigFromEnv_OriginPolicyInvalidUsesDefault(t *testing.T) {
|
||||
t.Setenv(envOpsWSOriginPolicy, "nope")
|
||||
cfg := loadOpsWSProxyConfigFromEnv()
|
||||
if cfg.OriginPolicy != OriginPolicyPermissive {
|
||||
t.Fatalf("OriginPolicy=%q, want %q", cfg.OriginPolicy, OriginPolicyPermissive)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustedProxyList(t *testing.T) {
|
||||
prefixes, invalid := parseTrustedProxyList("10.0.0.1, 10.0.0.0/8, bad, ::1/128")
|
||||
if len(prefixes) != 3 {
|
||||
t.Fatalf("prefixes=%d, want 3", len(prefixes))
|
||||
}
|
||||
if len(invalid) != 1 || invalid[0] != "bad" {
|
||||
t.Fatalf("invalid=%v, want [bad]", invalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestPeerIP_ParsesIPv6(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
req.RemoteAddr = "[::1]:1234"
|
||||
|
||||
addr, ok := requestPeerIP(req)
|
||||
if !ok {
|
||||
t.Fatalf("expected IPv6 peer IP to parse")
|
||||
}
|
||||
if addr != netip.MustParseAddr("::1") {
|
||||
t.Fatalf("addr=%s, want ::1", addr)
|
||||
}
|
||||
}
|
||||
@@ -36,22 +36,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPassword: settings.SMTPPassword,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
})
|
||||
@@ -64,13 +64,13 @@ type UpdateSettingsRequest struct {
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
@@ -81,9 +81,9 @@ type UpdateSettingsRequest struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -106,8 +106,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
@@ -143,22 +143,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocUrl: req.DocUrl,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
}
|
||||
@@ -178,67 +178,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPassword: updatedSettings.SMTPPassword,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// TestSMTPConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
var req TestSMTPRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -250,13 +250,13 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
@@ -268,27 +268,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SmtpPassword
|
||||
password := req.SMTPPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
Password: password,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
From: req.SMTPFrom,
|
||||
FromName: req.SMTPFromName,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
@@ -333,10 +333,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GetAdminAPIKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -348,10 +348,10 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -362,10 +362,10 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// SearchAPIKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleApiKey struct {
|
||||
type SimpleAPIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
result := make([]SimpleAPIKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleApiKey{
|
||||
result[i] = SimpleAPIKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
|
||||
@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
@@ -56,9 +56,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
@@ -108,7 +108,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
svcReq := service.CreateAPIKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
@@ -119,7 +119,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
@@ -143,7 +143,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
svcReq := service.UpdateAPIKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package dto provides mapping utilities for converting between service layer and HTTP handler DTOs.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -26,11 +27,11 @@ func UserFromService(u *service.User) *User {
|
||||
return nil
|
||||
}
|
||||
out := UserFromServiceShallow(u)
|
||||
if len(u.ApiKeys) > 0 {
|
||||
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
|
||||
for i := range u.ApiKeys {
|
||||
k := u.ApiKeys[i]
|
||||
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
|
||||
if len(u.APIKeys) > 0 {
|
||||
out.APIKeys = make([]APIKey, 0, len(u.APIKeys))
|
||||
for i := range u.APIKeys {
|
||||
k := u.APIKeys[i]
|
||||
out.APIKeys = append(out.APIKeys, *APIKeyFromService(&k))
|
||||
}
|
||||
}
|
||||
if len(u.Subscriptions) > 0 {
|
||||
@@ -43,11 +44,11 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &ApiKey{
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
@@ -220,7 +221,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
ApiKeyID: l.ApiKeyID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
@@ -245,7 +246,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
ApiKey: ApiKeyFromService(l.ApiKey),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
|
||||
@@ -5,13 +5,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password,omitempty"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
@@ -20,9 +20,9 @@ type SystemSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -36,8 +36,8 @@ type PublicSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
DocURL string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ type User struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
ApiKeys []ApiKey `json:"api_keys,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
@@ -136,7 +136,7 @@ type RedeemCode struct {
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
@@ -168,7 +168,7 @@ type UsageLog struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
ApiKey *ApiKey `json:"api_key,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package handler provides HTTP request handlers for the API gateway.
|
||||
// It handles authentication, request routing, concurrency control, and billing validation.
|
||||
package handler
|
||||
|
||||
import (
|
||||
@@ -27,6 +29,7 @@ type GatewayHandler struct {
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -37,6 +40,7 @@ func NewGatewayHandler(
|
||||
userService *service.UserService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
opsService *service.OpsService,
|
||||
) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
@@ -45,14 +49,15 @@ func NewGatewayHandler(
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
opsService: opsService,
|
||||
}
|
||||
}
|
||||
|
||||
// Messages handles Claude API compatible messages endpoint
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
// 从context获取apiKey和user(APIKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -87,6 +92,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -258,7 +264,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -382,7 +388,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -399,7 +405,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Returns models based on account configurations (model_mapping whitelist)
|
||||
// Falls back to default models if no whitelist is configured
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||
|
||||
var groupID *int64
|
||||
var platform string
|
||||
@@ -448,7 +454,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -573,6 +579,7 @@ func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string)
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
recordOpsError(c, h.opsService, status, errType, message, "")
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
@@ -604,6 +611,7 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
recordOpsError(c, h.opsService, status, errType, message, "")
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
@@ -617,8 +625,8 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
||||
// POST /v1/messages/count_tokens
|
||||
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
// 从context获取apiKey和user(APIKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -66,7 +66,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
// GeminiV1BetaGetModel proxies:
|
||||
// GET /v1beta/models/{model}
|
||||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -119,7 +119,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
// POST /v1beta/models/{model}:generateContent
|
||||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
@@ -298,7 +298,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
Ops *admin.OpsHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
|
||||
@@ -22,6 +22,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -29,19 +30,21 @@ func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
opsService *service.OpsService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
opsService: opsService,
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
// Get apiKey and user from context (set by APIKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -79,6 +82,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -235,7 +239,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -278,6 +282,7 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
@@ -297,6 +302,7 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
|
||||
166
backend/internal/handler/ops_error_logger.go
Normal file
166
backend/internal/handler/ops_error_logger.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
opsModelKey = "ops_model"
|
||||
opsStreamKey = "ops_stream"
|
||||
)
|
||||
|
||||
const (
|
||||
opsErrorLogWorkerCount = 10
|
||||
opsErrorLogQueueSize = 256
|
||||
opsErrorLogTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
ops *service.OpsService
|
||||
entry *service.OpsErrorLog
|
||||
}
|
||||
|
||||
var (
|
||||
opsErrorLogOnce sync.Once
|
||||
opsErrorLogQueue chan opsErrorLogJob
|
||||
)
|
||||
|
||||
func startOpsErrorLogWorkers() {
|
||||
opsErrorLogQueue = make(chan opsErrorLogJob, opsErrorLogQueueSize)
|
||||
for i := 0; i < opsErrorLogWorkerCount; i++ {
|
||||
go func() {
|
||||
for job := range opsErrorLogQueue {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry)
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsErrorLog) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opsErrorLogOnce.Do(startOpsErrorLogWorkers)
|
||||
|
||||
select {
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
|
||||
default:
|
||||
// Queue is full; drop to avoid blocking request handling.
|
||||
}
|
||||
}
|
||||
|
||||
func setOpsRequestContext(c *gin.Context, model string, stream bool) {
|
||||
c.Set(opsModelKey, model)
|
||||
c.Set(opsStreamKey, stream)
|
||||
}
|
||||
|
||||
func recordOpsError(c *gin.Context, ops *service.OpsService, status int, errType, message, fallbackPlatform string) {
|
||||
if ops == nil || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := c.Get(opsModelKey)
|
||||
stream, _ := c.Get(opsStreamKey)
|
||||
|
||||
var modelName string
|
||||
if m, ok := model.(string); ok {
|
||||
modelName = m
|
||||
}
|
||||
streaming, _ := stream.(bool)
|
||||
|
||||
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||
|
||||
logEntry := &service.OpsErrorLog{
|
||||
Phase: classifyOpsPhase(errType, message),
|
||||
Type: errType,
|
||||
Severity: classifyOpsSeverity(errType, status),
|
||||
StatusCode: status,
|
||||
Platform: resolveOpsPlatform(apiKey, fallbackPlatform),
|
||||
Model: modelName,
|
||||
RequestID: c.Writer.Header().Get("x-request-id"),
|
||||
Message: message,
|
||||
ClientIP: c.ClientIP(),
|
||||
RequestPath: func() string {
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
return c.Request.URL.Path
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Stream: streaming,
|
||||
}
|
||||
|
||||
if apiKey != nil {
|
||||
logEntry.APIKeyID = &apiKey.ID
|
||||
if apiKey.User != nil {
|
||||
logEntry.UserID = &apiKey.User.ID
|
||||
}
|
||||
if apiKey.GroupID != nil {
|
||||
logEntry.GroupID = apiKey.GroupID
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsErrorLog(ops, logEntry)
|
||||
}
|
||||
|
||||
func resolveOpsPlatform(apiKey *service.APIKey, fallback string) string {
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform != "" {
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func classifyOpsPhase(errType, message string) string {
|
||||
msg := strings.ToLower(message)
|
||||
switch errType {
|
||||
case "authentication_error":
|
||||
return "auth"
|
||||
case "billing_error", "subscription_error":
|
||||
return "billing"
|
||||
case "rate_limit_error":
|
||||
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") {
|
||||
return "concurrency"
|
||||
}
|
||||
return "upstream"
|
||||
case "invalid_request_error":
|
||||
return "response"
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
return "scheduling"
|
||||
}
|
||||
return "internal"
|
||||
default:
|
||||
return "internal"
|
||||
}
|
||||
}
|
||||
|
||||
func classifyOpsSeverity(errType string, status int) string {
|
||||
switch errType {
|
||||
case "invalid_request_error", "authentication_error", "billing_error", "subscription_error":
|
||||
return "P3"
|
||||
}
|
||||
if status >= 500 {
|
||||
return "P1"
|
||||
}
|
||||
if status == 429 {
|
||||
return "P1"
|
||||
}
|
||||
if status >= 400 {
|
||||
return "P2"
|
||||
}
|
||||
return "P3"
|
||||
}
|
||||
@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DocURL: settings.DocURL,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.ApiKeyService
|
||||
apiKeyService *service.APIKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchApiKeysUsageRequest
|
||||
var req BatchAPIKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
if len(req.APIKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
if len(validAPIKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
func ProvideAdminHandlers(
|
||||
dashboardHandler *admin.DashboardHandler,
|
||||
opsHandler *admin.OpsHandler,
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
Ops: opsHandler,
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
admin.NewOpsHandler,
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package antigravity provides a client for interacting with Google's Antigravity API,
|
||||
// handling OAuth authentication, token management, and account tier information retrieval.
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package claude provides Claude API client constants and utilities.
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
@@ -16,13 +17,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
// DefaultHeaders are the default request headers for Claude Code client.
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package errors provides custom error types and error handling utilities.
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
|
||||
package gemini
|
||||
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
// This package is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package geminicli provides OAuth authentication and API client functionality
|
||||
// for Google's Gemini AI services, supporting both AI Studio and Code Assist endpoints.
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package googleapi provides utilities for Google API interactions.
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package oauth provides OAuth 2.0 utilities including PKCE flow, session management, and token exchange.
|
||||
package oauth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package openai provides OpenAI API models and configuration.
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
// UserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package pagination provides utilities for handling paginated queries and results.
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package response provides HTTP response utilities for standardized API responses and error handling.
|
||||
package response
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package sysutil provides system-level utilities for service management.
|
||||
package sysutil
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package usagestats defines types for tracking and reporting API usage statistics.
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
@@ -10,8 +11,8 @@ type DashboardStats struct {
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
@@ -82,10 +83,10 @@ type UserUsageTrendPoint struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
@@ -94,8 +95,8 @@ type ApiKeyUsageTrendPoint struct {
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
@@ -128,7 +129,7 @@ type UserDashboardStats struct {
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
@@ -157,9 +158,9 @@ type BatchUserUsageStats struct {
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
},
|
||||
accType: service.AccountTypeApiKey,
|
||||
accType: service.AccountTypeAPIKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
|
||||
@@ -24,7 +24,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
type APIKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
@@ -78,7 +78,7 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
func (s *APIKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
@@ -122,6 +122,6 @@ func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
func TestAPIKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyCacheSuite))
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
func TestAPIKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
|
||||
@@ -16,17 +16,17 @@ type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 不加载完整的 APIKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.ApiKey.Update().
|
||||
builder := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
exists, err := r.client.APIKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
ids, err := r.client.APIKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.ApiKey.Update().
|
||||
n, err := r.client.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.ApiKey{
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
|
||||
@@ -12,30 +12,30 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyRepoSuite struct {
|
||||
type APIKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
func (s *APIKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
func (s *APIKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
func (s *APIKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.ApiKey{
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
@@ -150,10 +150,10 @@ func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
@@ -161,10 +161,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
@@ -174,10 +174,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
@@ -186,13 +186,13 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
@@ -202,10 +202,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
@@ -214,9 +214,9 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
@@ -227,47 +227,47 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchApiKeys ---
|
||||
// --- SearchAPIKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
@@ -284,10 +284,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
@@ -320,13 +320,13 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.ApiKey{
|
||||
k := &service.APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
|
||||
@@ -27,8 +27,14 @@ const (
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
|
||||
// Wait queue keys (global structures)
|
||||
// - total: integer total queue depth across all users
|
||||
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
|
||||
// - counts: hash of userID -> current wait count
|
||||
waitQueueTotalKey = "concurrency:wait:total"
|
||||
waitQueueUpdatedKey = "concurrency:wait:updated"
|
||||
waitQueueCountsKey = "concurrency:wait:counts"
|
||||
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||
accountWaitKeyPrefix = "wait:account:"
|
||||
|
||||
@@ -94,27 +100,55 @@ var (
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
// KEYS[1] = total key
|
||||
// KEYS[2] = updated zset key
|
||||
// KEYS[3] = counts hash key
|
||||
// ARGV[1] = userID
|
||||
// ARGV[2] = maxWait
|
||||
// ARGV[3] = TTL in seconds
|
||||
// ARGV[4] = cleanup limit
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
local totalKey = KEYS[1]
|
||||
local updatedKey = KEYS[2]
|
||||
local countsKey = KEYS[3]
|
||||
|
||||
local userID = ARGV[1]
|
||||
local maxWait = tonumber(ARGV[2])
|
||||
local ttl = tonumber(ARGV[3])
|
||||
local cleanupLimit = tonumber(ARGV[4])
|
||||
|
||||
redis.call('SETNX', totalKey, 0)
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- Cleanup expired users (bounded)
|
||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
||||
for _, uid in ipairs(expired) do
|
||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
||||
if c > 0 then
|
||||
redis.call('DECRBY', totalKey, c)
|
||||
end
|
||||
redis.call('HDEL', countsKey, uid)
|
||||
redis.call('ZREM', updatedKey, uid)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
local newVal = current + 1
|
||||
redis.call('HSET', countsKey, userID, newVal)
|
||||
redis.call('ZADD', updatedKey, now, userID)
|
||||
redis.call('INCR', totalKey)
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
-- Keep global structures from living forever in totally idle deployments.
|
||||
local ttlKeep = ttl * 2
|
||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
||||
|
||||
return 1
|
||||
`)
|
||||
@@ -144,6 +178,111 @@ var (
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local totalKey = KEYS[1]
|
||||
local updatedKey = KEYS[2]
|
||||
local countsKey = KEYS[3]
|
||||
|
||||
local userID = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local cleanupLimit = tonumber(ARGV[3])
|
||||
|
||||
redis.call('SETNX', totalKey, 0)
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- Cleanup expired users (bounded)
|
||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
||||
for _, uid in ipairs(expired) do
|
||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
||||
if c > 0 then
|
||||
redis.call('DECRBY', totalKey, c)
|
||||
end
|
||||
redis.call('HDEL', countsKey, uid)
|
||||
redis.call('ZREM', updatedKey, uid)
|
||||
end
|
||||
|
||||
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
|
||||
if current <= 0 then
|
||||
return 1
|
||||
end
|
||||
|
||||
local newVal = current - 1
|
||||
if newVal <= 0 then
|
||||
redis.call('HDEL', countsKey, userID)
|
||||
redis.call('ZREM', updatedKey, userID)
|
||||
else
|
||||
redis.call('HSET', countsKey, userID, newVal)
|
||||
redis.call('ZADD', updatedKey, now, userID)
|
||||
end
|
||||
redis.call('DECR', totalKey)
|
||||
|
||||
local ttlKeep = ttl * 2
|
||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getTotalWaitScript returns the global wait depth with TTL cleanup.
|
||||
// KEYS[1] = total key
|
||||
// KEYS[2] = updated zset key
|
||||
// KEYS[3] = counts hash key
|
||||
// ARGV[1] = TTL in seconds
|
||||
// ARGV[2] = cleanup limit
|
||||
getTotalWaitScript = redis.NewScript(`
|
||||
local totalKey = KEYS[1]
|
||||
local updatedKey = KEYS[2]
|
||||
local countsKey = KEYS[3]
|
||||
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local cleanupLimit = tonumber(ARGV[2])
|
||||
|
||||
redis.call('SETNX', totalKey, 0)
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- Cleanup expired users (bounded)
|
||||
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
|
||||
for _, uid in ipairs(expired) do
|
||||
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
|
||||
if c > 0 then
|
||||
redis.call('DECRBY', totalKey, c)
|
||||
end
|
||||
redis.call('HDEL', countsKey, uid)
|
||||
redis.call('ZREM', updatedKey, uid)
|
||||
end
|
||||
|
||||
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
|
||||
local total = redis.call('GET', totalKey)
|
||||
if total == false then
|
||||
total = 0
|
||||
local vals = redis.call('HVALS', countsKey)
|
||||
for _, v in ipairs(vals) do
|
||||
total = total + tonumber(v)
|
||||
end
|
||||
redis.call('SET', totalKey, total)
|
||||
end
|
||||
|
||||
local ttlKeep = ttl * 2
|
||||
redis.call('EXPIRE', totalKey, ttlKeep)
|
||||
redis.call('EXPIRE', updatedKey, ttlKeep)
|
||||
redis.call('EXPIRE', countsKey, ttlKeep)
|
||||
|
||||
local result = tonumber(redis.call('GET', totalKey) or '0')
|
||||
if result < 0 then
|
||||
result = 0
|
||||
redis.call('SET', totalKey, 0)
|
||||
end
|
||||
return result
|
||||
`)
|
||||
|
||||
// decrementAccountWaitScript - account-level wait queue decrement
|
||||
decrementAccountWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
@@ -244,7 +383,9 @@ func userSlotKey(userID int64) string {
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
// Historical: per-user string keys were used.
|
||||
// Now we use global structures keyed by userID string.
|
||||
return strconv.FormatInt(userID, 10)
|
||||
}
|
||||
|
||||
func accountWaitKey(accountID int64) string {
|
||||
@@ -308,8 +449,16 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
userKey := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(
|
||||
ctx,
|
||||
c.rdb,
|
||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
||||
userKey,
|
||||
maxWait,
|
||||
c.waitQueueTTLSeconds,
|
||||
200, // cleanup limit per call
|
||||
).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -317,11 +466,35 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
userKey := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(
|
||||
ctx,
|
||||
c.rdb,
|
||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
||||
userKey,
|
||||
c.waitQueueTTLSeconds,
|
||||
200, // cleanup limit per call
|
||||
).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
|
||||
if c.rdb == nil {
|
||||
return 0, nil
|
||||
}
|
||||
total, err := getTotalWaitScript.Run(
|
||||
ctx,
|
||||
c.rdb,
|
||||
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
|
||||
c.waitQueueTTLSeconds,
|
||||
500, // cleanup limit per query (rare)
|
||||
).Int64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(total), nil
|
||||
}
|
||||
|
||||
// Account wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
@@ -335,7 +508,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
|
||||
|
||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
key := accountWaitKey(accountID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
userKey := waitQueueKey(userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
|
||||
require.NoError(s.T(), err, "TTL wait total key")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
|
||||
require.NoError(s.T(), err, "HGET wait queue count")
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
|
||||
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
|
||||
require.NoError(s.T(), err, "GET wait queue total")
|
||||
require.Equal(s.T(), 1, total, "expected total wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
userKey := waitQueueKey(userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
// Verify count remains zero / absent.
|
||||
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
@@ -210,12 +210,15 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
// Verify per-user count is absent and total is non-negative.
|
||||
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
|
||||
|
||||
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// Package repository 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
|
||||
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
|
||||
k.Name = "default"
|
||||
}
|
||||
|
||||
create := client.ApiKey.Create().
|
||||
create := client.APIKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
|
||||
@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.ApiKey.Update().
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
|
||||
190
backend/internal/repository/ops.go
Normal file
190
backend/internal/repository/ops.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// ListErrorLogs queries ops_error_logs with optional filters and pagination.
|
||||
// It returns the list items and the total count of matching rows.
|
||||
func (r *OpsRepository) ListErrorLogs(ctx context.Context, filter *service.ErrorLogFilter) ([]*service.ErrorLog, int64, error) {
|
||||
page := 1
|
||||
pageSize := 20
|
||||
if filter != nil {
|
||||
if filter.Page > 0 {
|
||||
page = filter.Page
|
||||
}
|
||||
if filter.PageSize > 0 {
|
||||
pageSize = filter.PageSize
|
||||
}
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
conditions := make([]string, 0)
|
||||
args := make([]any, 0)
|
||||
|
||||
addCondition := func(condition string, values ...any) {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, values...)
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
// 默认查询最近 24 小时
|
||||
if filter.StartTime == nil && filter.EndTime == nil {
|
||||
defaultStart := time.Now().Add(-24 * time.Hour)
|
||||
filter.StartTime = &defaultStart
|
||||
}
|
||||
|
||||
if filter.StartTime != nil {
|
||||
addCondition(fmt.Sprintf("created_at >= $%d", len(args)+1), *filter.StartTime)
|
||||
}
|
||||
if filter.EndTime != nil {
|
||||
addCondition(fmt.Sprintf("created_at <= $%d", len(args)+1), *filter.EndTime)
|
||||
}
|
||||
if filter.ErrorCode != nil {
|
||||
addCondition(fmt.Sprintf("status_code = $%d", len(args)+1), *filter.ErrorCode)
|
||||
}
|
||||
if provider := strings.TrimSpace(filter.Provider); provider != "" {
|
||||
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), provider)
|
||||
}
|
||||
if filter.AccountID != nil {
|
||||
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
where := ""
|
||||
if len(conditions) > 0 {
|
||||
where = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
countQuery := fmt.Sprintf(`SELECT COUNT(1) FROM ops_error_logs %s`, where)
|
||||
var total int64
|
||||
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
total = 0
|
||||
} else {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
listQuery := fmt.Sprintf(`
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
severity,
|
||||
request_id,
|
||||
account_id,
|
||||
request_path,
|
||||
platform,
|
||||
model,
|
||||
status_code,
|
||||
error_message,
|
||||
duration_ms,
|
||||
retry_count,
|
||||
stream
|
||||
FROM ops_error_logs
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, where, len(args)+1, len(args)+2)
|
||||
|
||||
listArgs := append(append([]any{}, args...), pageSize, offset)
|
||||
rows, err := r.sql.QueryContext(ctx, listQuery, listArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
results := make([]*service.ErrorLog, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
createdAt time.Time
|
||||
severity sql.NullString
|
||||
requestID sql.NullString
|
||||
accountID sql.NullInt64
|
||||
requestURI sql.NullString
|
||||
platform sql.NullString
|
||||
model sql.NullString
|
||||
statusCode sql.NullInt64
|
||||
message sql.NullString
|
||||
durationMs sql.NullInt64
|
||||
retryCount sql.NullInt64
|
||||
stream sql.NullBool
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&createdAt,
|
||||
&severity,
|
||||
&requestID,
|
||||
&accountID,
|
||||
&requestURI,
|
||||
&platform,
|
||||
&model,
|
||||
&statusCode,
|
||||
&message,
|
||||
&durationMs,
|
||||
&retryCount,
|
||||
&stream,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
entry := &service.ErrorLog{
|
||||
ID: id,
|
||||
Timestamp: createdAt,
|
||||
Level: levelFromSeverity(severity.String),
|
||||
RequestID: requestID.String,
|
||||
APIPath: requestURI.String,
|
||||
Provider: platform.String,
|
||||
Model: model.String,
|
||||
HTTPCode: int(statusCode.Int64),
|
||||
Stream: stream.Bool,
|
||||
}
|
||||
if accountID.Valid {
|
||||
entry.AccountID = strconv.FormatInt(accountID.Int64, 10)
|
||||
}
|
||||
if message.Valid {
|
||||
entry.ErrorMessage = message.String
|
||||
}
|
||||
if durationMs.Valid {
|
||||
v := int(durationMs.Int64)
|
||||
entry.DurationMs = &v
|
||||
}
|
||||
if retryCount.Valid {
|
||||
v := int(retryCount.Int64)
|
||||
entry.RetryCount = &v
|
||||
}
|
||||
|
||||
results = append(results, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return results, total, nil
|
||||
}
|
||||
|
||||
func levelFromSeverity(severity string) string {
|
||||
sev := strings.ToUpper(strings.TrimSpace(severity))
|
||||
switch sev {
|
||||
case "P0", "P1":
|
||||
return "CRITICAL"
|
||||
case "P2":
|
||||
return "ERROR"
|
||||
case "P3":
|
||||
return "WARN"
|
||||
default:
|
||||
return "ERROR"
|
||||
}
|
||||
}
|
||||
127
backend/internal/repository/ops_cache.go
Normal file
127
backend/internal/repository/ops_cache.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
opsLatestMetricsKey = "ops:metrics:latest"
|
||||
|
||||
opsDashboardOverviewKeyPrefix = "ops:dashboard:overview:"
|
||||
|
||||
opsLatestMetricsTTL = 10 * time.Second
|
||||
)
|
||||
|
||||
func (r *OpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*service.OpsMetrics, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if r == nil || r.rdb == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := r.rdb.Get(ctx, opsLatestMetricsKey).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis get cached latest system metric: %w", err)
|
||||
}
|
||||
|
||||
var metric service.OpsMetrics
|
||||
if err := json.Unmarshal(data, &metric); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal cached latest system metric: %w", err)
|
||||
}
|
||||
return &metric, nil
|
||||
}
|
||||
|
||||
func (r *OpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *service.OpsMetrics) error {
|
||||
if metric == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if r == nil || r.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(metric)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal cached latest system metric: %w", err)
|
||||
}
|
||||
return r.rdb.Set(ctx, opsLatestMetricsKey, data, opsLatestMetricsTTL).Err()
|
||||
}
|
||||
|
||||
func (r *OpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*service.DashboardOverviewData, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if r == nil || r.rdb == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rangeKey := strings.TrimSpace(timeRange)
|
||||
if rangeKey == "" {
|
||||
rangeKey = "1h"
|
||||
}
|
||||
|
||||
key := opsDashboardOverviewKeyPrefix + rangeKey
|
||||
data, err := r.rdb.Get(ctx, key).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis get cached dashboard overview: %w", err)
|
||||
}
|
||||
|
||||
var overview service.DashboardOverviewData
|
||||
if err := json.Unmarshal(data, &overview); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal cached dashboard overview: %w", err)
|
||||
}
|
||||
return &overview, nil
|
||||
}
|
||||
|
||||
func (r *OpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *service.DashboardOverviewData, ttl time.Duration) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = 10 * time.Second
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if r == nil || r.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rangeKey := strings.TrimSpace(timeRange)
|
||||
if rangeKey == "" {
|
||||
rangeKey = "1h"
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal cached dashboard overview: %w", err)
|
||||
}
|
||||
key := opsDashboardOverviewKeyPrefix + rangeKey
|
||||
return r.rdb.Set(ctx, key, payload, ttl).Err()
|
||||
}
|
||||
|
||||
func (r *OpsRepository) PingRedis(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if r == nil || r.rdb == nil {
|
||||
return errors.New("redis client is nil")
|
||||
}
|
||||
return r.rdb.Ping(ctx).Err()
|
||||
}
|
||||
1333
backend/internal/repository/ops_repo.go
Normal file
1333
backend/internal/repository/ops_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
|
||||
return u
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,确保软删除验证在实际持久化数据上进行。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
Name: "soft-delete",
|
||||
@@ -53,28 +53,28 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
_, err := repo.GetByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
|
||||
|
||||
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
require.Error(t, err, "default ent query should not see soft-deleted rows")
|
||||
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
|
||||
|
||||
got, err := client.ApiKey.Query().
|
||||
got, err := client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
|
||||
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,避免事务回滚影响幂等性验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
Name: "soft-delete2",
|
||||
@@ -86,15 +86,15 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
Name: "soft-delete3",
|
||||
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
|
||||
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "hard delete")
|
||||
|
||||
_, err = client.ApiKey.Query().
|
||||
_, err = client.APIKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
|
||||
|
||||
@@ -117,7 +117,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.ApiKeyID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
@@ -183,7 +183,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||||
}
|
||||
|
||||
@@ -270,8 +270,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,8 +418,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
@@ -623,7 +623,7 @@ func resolveUsageStatsTimezone() string {
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
@@ -709,11 +709,11 @@ type ModelStat = usagestats.ModelStat
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -755,10 +755,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||
results = make([]APIKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row ApiKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
var row APIKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
@@ -844,7 +844,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.TotalAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -853,7 +853,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveApiKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1023,9 +1023,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.ApiKeyID > 0 {
|
||||
if filters.APIKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.ApiKeyID)
|
||||
args = append(args, filters.APIKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
@@ -1145,18 +1145,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchApiKeyUsageStats)
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
@@ -1582,7 +1582,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
|
||||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1603,8 +1603,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if user, ok := users[logs[i].UserID]; ok {
|
||||
logs[i].User = user
|
||||
}
|
||||
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
|
||||
logs[i].ApiKey = key
|
||||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||||
logs[i].APIKey = key
|
||||
}
|
||||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||||
logs[i].Account = acc
|
||||
@@ -1642,7 +1642,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||||
|
||||
for i := range logs {
|
||||
userIDs[logs[i].UserID] = struct{}{}
|
||||
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
|
||||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||||
accountIDs[logs[i].AccountID] = struct{}{}
|
||||
if logs[i].GroupID != nil {
|
||||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||||
@@ -1676,12 +1676,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
|
||||
out := make(map[int64]*service.ApiKey)
|
||||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||||
out := make(map[int64]*service.APIKey)
|
||||
if len(ids) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1800,7 +1800,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
log := &service.UsageLog{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
|
||||
@@ -35,10 +35,10 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: inputTokens,
|
||||
@@ -55,12 +55,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
@@ -76,7 +76,7 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -96,7 +96,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -112,7 +112,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -124,18 +124,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
// --- ListByApiKey ---
|
||||
// --- ListByAPIKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKey() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByApiKey")
|
||||
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByAPIKey")
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -159,7 +159,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -179,7 +179,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -211,8 +211,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
|
||||
@@ -223,7 +223,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
GroupID: &group.ID,
|
||||
@@ -240,7 +240,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
|
||||
logOld := &service.UsageLog{
|
||||
UserID: userOld.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
@@ -254,7 +254,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
|
||||
logPerf := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 1,
|
||||
@@ -272,8 +272,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
|
||||
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
|
||||
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
|
||||
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
|
||||
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||
@@ -300,14 +300,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetUserDashboardStats")
|
||||
s.Require().Equal(int64(1), stats.TotalApiKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalAPIKeys)
|
||||
s.Require().Equal(int64(1), stats.TotalRequests)
|
||||
}
|
||||
|
||||
@@ -315,7 +315,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
@@ -331,8 +331,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
@@ -351,24 +351,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
// --- GetBatchAPIKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
|
||||
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -377,7 +377,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -402,7 +402,7 @@ func maxTime(a, b time.Time) time.Time {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -417,11 +417,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
// --- ListByAPIKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -431,8 +431,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
|
||||
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
|
||||
s.Require().Len(logs, 2)
|
||||
}
|
||||
|
||||
@@ -440,7 +440,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -459,7 +459,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -467,7 +467,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 10,
|
||||
@@ -480,7 +480,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 15,
|
||||
@@ -493,7 +493,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 20,
|
||||
@@ -515,7 +515,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
@@ -535,7 +535,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -552,7 +552,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -571,7 +571,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -579,7 +579,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
// Create logs with different models
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -592,7 +592,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -618,7 +618,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -646,7 +646,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -665,14 +665,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -685,7 +685,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -719,7 +719,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
@@ -727,7 +727,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
// Create logs on different days
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-opus",
|
||||
InputTokens: 100,
|
||||
@@ -740,7 +740,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3-sonnet",
|
||||
InputTokens: 50,
|
||||
@@ -782,8 +782,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -799,12 +799,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
// --- GetAPIKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -815,14 +815,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend")
|
||||
s.Require().GreaterOrEqual(len(trend), 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -832,21 +832,21 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
|
||||
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
|
||||
// --- ListWithFilters (additional filter tests) ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
|
||||
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
|
||||
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||
s.Require().NoError(err, "ListWithFilters apiKey")
|
||||
s.Require().Len(logs, 1)
|
||||
@@ -855,7 +855,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -874,7 +874,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
@@ -885,7 +885,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
}
|
||||
|
||||
@@ -28,12 +28,13 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
NewApiKeyRepository,
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewOpsRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
@@ -42,7 +43,7 @@ var ProviderSet = wire.NewSet(
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewAPIKeyCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
name: "GET /api/v1/keys (paginated)",
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.apiKeyRepo.MustSeed(&service.ApiKey{
|
||||
deps.apiKeyRepo.MustSeed(&service.APIKey{
|
||||
ID: 100,
|
||||
UserID: 1,
|
||||
Key: "sk_custom_1234567890",
|
||||
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 2,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
ApiKeyID: 100,
|
||||
APIKeyID: 100,
|
||||
AccountID: 200,
|
||||
RequestID: "req_123",
|
||||
Model: "claude-3",
|
||||
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
|
||||
service.SettingKeySmtpHost: "smtp.example.com",
|
||||
service.SettingKeySmtpPort: "587",
|
||||
service.SettingKeySmtpUsername: "user",
|
||||
service.SettingKeySmtpPassword: "secret",
|
||||
service.SettingKeySmtpFrom: "no-reply@example.com",
|
||||
service.SettingKeySmtpFromName: "Sub2API",
|
||||
service.SettingKeySmtpUseTLS: "true",
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
service.SettingKeySMTPUsername: "user",
|
||||
service.SettingKeySMTPPassword: "secret",
|
||||
service.SettingKeySMTPFrom: "no-reply@example.com",
|
||||
service.SettingKeySMTPFromName: "Sub2API",
|
||||
service.SettingKeySMTPUseTLS: "true",
|
||||
|
||||
service.SettingKeyTurnstileEnabled: "true",
|
||||
service.SettingKeyTurnstileSiteKey: "site-key",
|
||||
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
|
||||
service.SettingKeySiteName: "Sub2API",
|
||||
service.SettingKeySiteLogo: "",
|
||||
service.SettingKeySiteSubtitle: "Subtitle",
|
||||
service.SettingKeyApiBaseUrl: "https://api.example.com",
|
||||
service.SettingKeyAPIBaseURL: "https://api.example.com",
|
||||
service.SettingKeyContactInfo: "support",
|
||||
service.SettingKeyDocUrl: "https://docs.example.com",
|
||||
service.SettingKeyDocURL: "https://docs.example.com",
|
||||
|
||||
service.SettingKeyDefaultConcurrency: "5",
|
||||
service.SettingKeyDefaultBalance: "1.25",
|
||||
@@ -331,7 +331,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
type contractDeps struct {
|
||||
now time.Time
|
||||
router http.Handler
|
||||
apiKeyRepo *stubApiKeyRepo
|
||||
apiKeyRepo *stubAPIKeyRepo
|
||||
usageRepo *stubUsageLogRepo
|
||||
settingRepo *stubSettingRepo
|
||||
}
|
||||
@@ -359,20 +359,20 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
},
|
||||
}
|
||||
|
||||
apiKeyRepo := newStubApiKeyRepo(now)
|
||||
apiKeyCache := stubApiKeyCache{}
|
||||
apiKeyRepo := newStubAPIKeyRepo(now)
|
||||
apiKeyCache := stubAPIKeyCache{}
|
||||
groupRepo := stubGroupRepo{}
|
||||
userSubRepo := stubUserSubscriptionRepo{}
|
||||
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
ApiKeyPrefix: "sk-",
|
||||
APIKeyPrefix: "sk-",
|
||||
},
|
||||
RunMode: config.RunModeStandard,
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo)
|
||||
@@ -525,25 +525,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubApiKeyCache struct{}
|
||||
type stubAPIKeyCache struct{}
|
||||
|
||||
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -660,24 +660,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
type stubAPIKeyRepo struct {
|
||||
now time.Time
|
||||
|
||||
nextID int64
|
||||
byID map[int64]*service.ApiKey
|
||||
byKey map[string]*service.ApiKey
|
||||
byID map[int64]*service.APIKey
|
||||
byKey map[string]*service.APIKey
|
||||
}
|
||||
|
||||
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
|
||||
return &stubApiKeyRepo{
|
||||
func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo {
|
||||
return &stubAPIKeyRepo{
|
||||
now: now,
|
||||
nextID: 100,
|
||||
byID: make(map[int64]*service.ApiKey),
|
||||
byKey: make(map[string]*service.ApiKey),
|
||||
byID: make(map[int64]*service.APIKey),
|
||||
byKey: make(map[string]*service.APIKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
|
||||
func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
@@ -686,7 +686,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
|
||||
r.byKey[clone.Key] = &clone
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
if key == nil {
|
||||
return errors.New("nil key")
|
||||
}
|
||||
@@ -706,38 +706,38 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *key
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return key.UserID, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
found, ok := r.byKey[key]
|
||||
if !ok {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *found
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
if key == nil {
|
||||
return errors.New("nil key")
|
||||
}
|
||||
if _, ok := r.byID[key.ID]; !ok {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = r.now
|
||||
@@ -748,17 +748,17 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return service.ErrApiKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
delete(r.byID, id)
|
||||
delete(r.byKey, key.Key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
ids := make([]int64, 0, len(r.byID))
|
||||
for id := range r.byID {
|
||||
if r.byID[id].UserID == userID {
|
||||
@@ -776,7 +776,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
||||
end = len(ids)
|
||||
}
|
||||
|
||||
out := make([]service.ApiKey, 0, end-start)
|
||||
out := make([]service.APIKey, 0, end-start)
|
||||
for _, id := range ids[start:end] {
|
||||
clone := *r.byID[id]
|
||||
out = append(out, clone)
|
||||
@@ -796,7 +796,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
@@ -815,7 +815,7 @@ func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
for _, key := range r.byID {
|
||||
if key.UserID == userID {
|
||||
@@ -825,24 +825,24 @@ func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
_, ok := r.byKey[key]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -877,7 +877,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
|
||||
return out, paginationResult(total, params), nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -890,7 +890,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
|
||||
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -922,7 +922,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -975,7 +975,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -995,7 +995,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -1017,8 +1017,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
|
||||
// Apply filters
|
||||
var filtered []service.UsageLog
|
||||
for _, log := range logs {
|
||||
// Apply ApiKeyID filter
|
||||
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
|
||||
// Apply APIKeyID filter
|
||||
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
|
||||
continue
|
||||
}
|
||||
// Apply Model filter
|
||||
@@ -1151,8 +1151,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
|
||||
// Ensure compile-time interface compliance.
|
||||
var (
|
||||
_ service.UserRepository = (*stubUserRepo)(nil)
|
||||
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
|
||||
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
|
||||
_ service.APIKeyRepository = (*stubAPIKeyRepo)(nil)
|
||||
_ service.APIKeyCache = (*stubAPIKeyCache)(nil)
|
||||
_ service.GroupRepository = (*stubGroupRepo)(nil)
|
||||
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
|
||||
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package server provides HTTP server setup and routing configuration.
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -25,8 +26,8 @@ func ProvideRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
|
||||
@@ -32,7 +32,7 @@ func adminAuth(
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -52,19 +52,48 @@ func adminAuth(
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket 请求无法设置自定义 header,允许在 query 中携带凭证
|
||||
if isWebSocketRequest(c) {
|
||||
if token := strings.TrimSpace(c.Query("token")); token != "" {
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
func isWebSocketRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
|
||||
return true
|
||||
}
|
||||
conn := strings.ToLower(c.GetHeader("Connection"))
|
||||
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// validateAdminAPIKey 验证管理员 API Key
|
||||
func validateAdminAPIKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware {
|
||||
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -53,6 +53,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
return
|
||||
}
|
||||
@@ -60,35 +61,40 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -109,12 +115,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -131,6 +139,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -140,13 +149,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -157,13 +167,66 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
}
|
||||
}
|
||||
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) {
|
||||
if opsService == nil || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
errType := "authentication_error"
|
||||
phase := "auth"
|
||||
severity := "P3"
|
||||
switch status {
|
||||
case 403:
|
||||
errType = "billing_error"
|
||||
phase = "billing"
|
||||
case 429:
|
||||
errType = "rate_limit_error"
|
||||
phase = "billing"
|
||||
severity = "P2"
|
||||
case 500:
|
||||
errType = "api_error"
|
||||
phase = "internal"
|
||||
severity = "P1"
|
||||
}
|
||||
|
||||
logEntry := &service.OpsErrorLog{
|
||||
Phase: phase,
|
||||
Type: errType,
|
||||
Severity: severity,
|
||||
StatusCode: status,
|
||||
Message: message,
|
||||
ClientIP: c.ClientIP(),
|
||||
RequestPath: func() string {
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
return c.Request.URL.Path
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
}
|
||||
|
||||
if apiKey != nil {
|
||||
logEntry.APIKeyID = &apiKey.ID
|
||||
if apiKey.User != nil {
|
||||
logEntry.UserID = &apiKey.User.ID
|
||||
}
|
||||
if apiKey.GroupID != nil {
|
||||
logEntry.GroupID = apiKey.GroupID
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
logEntry.Platform = apiKey.Group.Platform
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsAuthErrorLog(opsService, logEntry)
|
||||
}
|
||||
|
||||
// GetAPIKeyFromContext 从上下文中获取API key
|
||||
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyAPIKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.ApiKey)
|
||||
apiKey, ok := value.(*service.APIKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
|
||||
@@ -11,16 +11,16 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
|
||||
@@ -16,53 +16,53 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
type fakeAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
}
|
||||
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
|
||||
return service.NewAPIKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
@@ -85,16 +85,16 @@ func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService
|
||||
)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -109,16 +109,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -134,16 +134,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -159,13 +159,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -192,13 +192,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.ApiKey{
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
apiKeyRepo := &stubAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
type stubAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
@@ -26,7 +29,7 @@ func Logger() gin.HandlerFunc {
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***")
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package middleware provides HTTP middleware components for authentication,
|
||||
// authorization, logging, error recovery, and request processing.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -15,8 +17,8 @@ const (
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeyAPIKey API密钥上下文键
|
||||
ContextKeyAPIKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
|
||||
55
backend/internal/server/middleware/ops_auth_error_logger.go
Normal file
55
backend/internal/server/middleware/ops_auth_error_logger.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAuthErrorLogWorkerCount = 10
|
||||
opsAuthErrorLogQueueSize = 256
|
||||
opsAuthErrorLogTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type opsAuthErrorLogJob struct {
|
||||
ops *service.OpsService
|
||||
entry *service.OpsErrorLog
|
||||
}
|
||||
|
||||
var (
|
||||
opsAuthErrorLogOnce sync.Once
|
||||
opsAuthErrorLogQueue chan opsAuthErrorLogJob
|
||||
)
|
||||
|
||||
func startOpsAuthErrorLogWorkers() {
|
||||
opsAuthErrorLogQueue = make(chan opsAuthErrorLogJob, opsAuthErrorLogQueueSize)
|
||||
for i := 0; i < opsAuthErrorLogWorkerCount; i++ {
|
||||
go func() {
|
||||
for job := range opsAuthErrorLogQueue {
|
||||
if job.ops == nil || job.entry == nil {
|
||||
continue
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAuthErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry)
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func enqueueOpsAuthErrorLog(ops *service.OpsService, entry *service.OpsErrorLog) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
opsAuthErrorLogOnce.Do(startOpsAuthErrorLogWorkers)
|
||||
|
||||
select {
|
||||
case opsAuthErrorLogQueue <- opsAuthErrorLogJob{ops: ops, entry: entry}:
|
||||
default:
|
||||
// Queue is full; drop to avoid blocking request handling.
|
||||
}
|
||||
}
|
||||
@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
// APIKeyAuthMiddleware API Key 认证中间件类型
|
||||
type APIKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
NewAPIKeyAuthMiddleware,
|
||||
)
|
||||
|
||||
@@ -17,8 +17,8 @@ func SetupRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) *gin.Engine {
|
||||
@@ -43,8 +43,8 @@ func registerRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
|
||||
@@ -19,6 +19,9 @@ func RegisterAdminRoutes(
|
||||
// 仪表盘
|
||||
registerDashboardRoutes(admin, h)
|
||||
|
||||
// 运维监控
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
// 用户管理
|
||||
registerUserManagementRoutes(admin, h)
|
||||
|
||||
@@ -67,10 +70,35 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
ops := admin.Group("/ops")
|
||||
{
|
||||
ops.GET("/metrics", h.Admin.Ops.GetMetrics)
|
||||
ops.GET("/metrics/history", h.Admin.Ops.ListMetricsHistory)
|
||||
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
|
||||
ops.GET("/error-logs", h.Admin.Ops.ListErrorLogs)
|
||||
|
||||
// Dashboard routes
|
||||
dashboard := ops.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
dashboard.GET("/providers", h.Admin.Ops.GetProviderHealth)
|
||||
dashboard.GET("/latency-histogram", h.Admin.Ops.GetLatencyHistogram)
|
||||
dashboard.GET("/errors/distribution", h.Admin.Ops.GetErrorDistribution)
|
||||
}
|
||||
|
||||
// WebSocket routes
|
||||
ws := ops.Group("/ws")
|
||||
{
|
||||
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,12 +231,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +276,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package routes 提供 HTTP 路由注册和处理函数
|
||||
package routes
|
||||
|
||||
import (
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
func RegisterGatewayRoutes(
|
||||
r *gin.Engine,
|
||||
h *handler.Handlers,
|
||||
apiKeyAuth middleware.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
apiKeyAuth middleware.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
gemini := r.Group("/v1beta")
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -62,7 +62,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
|
||||
@@ -50,7 +50,7 @@ func RegisterUserRoutes(
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
|
||||
@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string {
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
func (a *Account) IsOpenAIAPIKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
if a.Type == AccountTypeAPIKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string {
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
func (a *Account) GetOpenAIAPIKey() string {
|
||||
if !a.IsOpenAIAPIKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
|
||||
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
|
||||
package service
|
||||
|
||||
import (
|
||||
|
||||
@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
authToken = account.GetOpenAIAPIKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
|
||||
@@ -17,11 +17,11 @@ type UsageLogRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
@@ -32,10 +32,10 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
@@ -51,7 +51,7 @@ type UsageLogRepository interface {
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
|
||||
@@ -19,7 +19,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
@@ -30,7 +30,7 @@ type AdminService interface {
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
@@ -65,7 +65,7 @@ type AdminService interface {
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
}
|
||||
|
||||
// Input types for admin operations
|
||||
// CreateUserInput represents the input for creating a new user
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
@@ -220,7 +220,7 @@ type adminServiceImpl struct {
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo ApiKeyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
@@ -232,7 +232,7 @@ func NewAdminService(
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
type ApiKey struct {
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
@@ -15,6 +15,6 @@ type ApiKey struct {
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
@@ -14,39 +14,39 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
Update(ctx context.Context, key *ApiKey) error
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo ApiKeyRepository
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache ApiKeyCache
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache ApiKeyCache,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
) *APIKeyService {
|
||||
return &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -99,7 +99,7 @@ func NewApiKeyService(
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrApiKeyTooShort
|
||||
return ErrAPIKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
return ErrAPIKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrApiKeyRateLimited
|
||||
return ErrAPIKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &ApiKey{
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
|
||||
// Delete 删除API Key
|
||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||||
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound)
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
|
||||
return s.ownerID, s.ownerErr
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
panic("unexpected SearchApiKeys call")
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 1
|
||||
// - 调用者 userID 为 2(不匹配)
|
||||
// - 返回 ErrInsufficientPerms 错误
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 1}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 7
|
||||
// - 调用者 userID 为 7(匹配)
|
||||
// - Delete 成功执行
|
||||
// - 缓存被正确清除(使用 ownerID)
|
||||
// - 返回 nil 错误
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
func TestAPIKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 7}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
|
||||
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
|
||||
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrApiKeyNotFound)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回正确的所有者 ID
|
||||
// - 所有权验证通过
|
||||
// - 缓存被清除(在删除之前)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete api key" 的错误信息
|
||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
ownerID: 3,
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
|
||||
@@ -32,6 +32,7 @@ type ConcurrencyCache interface {
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
GetTotalWaitCount(ctx context.Context) (int, error)
|
||||
|
||||
// 批量负载查询(只读)
|
||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||
@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
||||
}
|
||||
}
|
||||
|
||||
// GetTotalWaitCount returns the total wait queue depth across users.
|
||||
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return s.cache.GetTotalWaitCount(ctx)
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
|
||||
@@ -82,7 +82,7 @@ type crsExportResponse struct {
|
||||
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
|
||||
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
|
||||
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
|
||||
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
|
||||
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
|
||||
@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const (
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -64,13 +64,13 @@ const (
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySMTPPort = "smtp_port" // SMTP端口
|
||||
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySMTPFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -81,20 +81,20 @@ const (
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
)
|
||||
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -40,8 +40,8 @@ const (
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
// SMTPConfig SMTP配置
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
}
|
||||
}
|
||||
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
keys := []string{
|
||||
SettingKeySmtpHost,
|
||||
SettingKeySmtpPort,
|
||||
SettingKeySmtpUsername,
|
||||
SettingKeySmtpPassword,
|
||||
SettingKeySmtpFrom,
|
||||
SettingKeySmtpFromName,
|
||||
SettingKeySmtpUseTLS,
|
||||
SettingKeySMTPHost,
|
||||
SettingKeySMTPPort,
|
||||
SettingKeySMTPUsername,
|
||||
SettingKeySMTPPassword,
|
||||
SettingKeySMTPFrom,
|
||||
SettingKeySMTPFromName,
|
||||
SettingKeySMTPUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySmtpHost]
|
||||
host := settings[SettingKeySMTPHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
|
||||
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[SettingKeySmtpUseTLS] == "true"
|
||||
useTLS := settings[SettingKeySMTPUseTLS] == "true"
|
||||
|
||||
return &SmtpConfig{
|
||||
return &SMTPConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySmtpUsername],
|
||||
Password: settings[SettingKeySmtpPassword],
|
||||
From: settings[SettingKeySmtpFrom],
|
||||
FromName: settings[SettingKeySmtpFromName],
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSmtpConfig(ctx)
|
||||
config, err := s.GetSMTPConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
|
||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
|
||||
@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
return claude.APIKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.ApiKeyBetaHeader
|
||||
return claude.APIKeyBetaHeader
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
@@ -1639,7 +1639,7 @@ type RecordUsageInput struct {
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
|
||||
@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string {
|
||||
return TierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||
// FetchGoogleOneTier fetches Google One tier from Drive API
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
|
||||
@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetOpenAIAPIKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
|
||||
99
backend/internal/service/ops.go
Normal file
99
backend/internal/service/ops.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorLog represents an ops error log item for list queries.
|
||||
//
|
||||
// Field naming matches docs/API-运维监控中心2.0.md (L3 根因追踪 - 错误日志列表).
|
||||
type ErrorLog struct {
|
||||
ID int64 `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
Level string `json:"level,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
APIPath string `json:"api_path,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
HTTPCode int `json:"http_code,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
|
||||
DurationMs *int `json:"duration_ms,omitempty"`
|
||||
RetryCount *int `json:"retry_count,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorLogFilter describes optional filters and pagination for listing ops error logs.
|
||||
type ErrorLogFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
ErrorCode *int
|
||||
Provider string
|
||||
AccountID *int64
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
func (f *ErrorLogFilter) normalize() (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
if f == nil {
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
if f.Page > 0 {
|
||||
page = f.Page
|
||||
}
|
||||
if f.PageSize > 0 {
|
||||
pageSize = f.PageSize
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
type ErrorLogListResponse struct {
|
||||
Errors []*ErrorLog `json:"errors"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *ErrorLogFilter) (*ErrorLogListResponse, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return &ErrorLogListResponse{
|
||||
Errors: []*ErrorLog{},
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}, nil
|
||||
}
|
||||
|
||||
page, pageSize := filter.normalize()
|
||||
if filter == nil {
|
||||
filter = &ErrorLogFilter{}
|
||||
}
|
||||
filter.Page = page
|
||||
filter.PageSize = pageSize
|
||||
|
||||
items, total, err := s.repo.ListErrorLogs(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if items == nil {
|
||||
items = []*ErrorLog{}
|
||||
}
|
||||
|
||||
return &ErrorLogListResponse{
|
||||
Errors: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
834
backend/internal/service/ops_alert_service.go
Normal file
834
backend/internal/service/ops_alert_service.go
Normal file
@@ -0,0 +1,834 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpsAlertService struct {
|
||||
opsService *OpsService
|
||||
userService *UserService
|
||||
emailService *EmailService
|
||||
httpClient *http.Client
|
||||
|
||||
interval time.Duration
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
stopCtx context.Context
|
||||
stop context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// opsAlertEvalInterval defines how often OpsAlertService evaluates alert rules.
|
||||
//
|
||||
// Production uses opsMetricsInterval. Tests may override this variable to keep
|
||||
// integration tests fast without changing production defaults.
|
||||
var opsAlertEvalInterval = opsMetricsInterval
|
||||
|
||||
func NewOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
|
||||
return &OpsAlertService{
|
||||
opsService: opsService,
|
||||
userService: userService,
|
||||
emailService: emailService,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
interval: opsAlertEvalInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the background alert evaluation loop.
|
||||
//
|
||||
// Stop must be called during shutdown to ensure the goroutine exits.
|
||||
func (s *OpsAlertService) Start() {
|
||||
s.StartWithContext(context.Background())
|
||||
}
|
||||
|
||||
// StartWithContext is like Start but allows the caller to provide a parent context.
|
||||
// When the parent context is canceled, the service stops automatically.
|
||||
func (s *OpsAlertService) StartWithContext(ctx context.Context) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
if s.interval <= 0 {
|
||||
s.interval = opsAlertEvalInterval
|
||||
}
|
||||
|
||||
s.stopCtx, s.stop = context.WithCancel(ctx)
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
})
|
||||
}
|
||||
|
||||
// Stop gracefully stops the background goroutine started by Start/StartWithContext.
|
||||
// It is safe to call Stop multiple times.
|
||||
func (s *OpsAlertService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stop != nil {
|
||||
s.stop()
|
||||
}
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) run() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.evaluateOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.evaluateOnce()
|
||||
case <-s.stopCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) evaluateOnce() {
|
||||
ctx, cancel := context.WithTimeout(s.stopCtx, opsAlertEvaluateTimeout)
|
||||
defer cancel()
|
||||
|
||||
s.Evaluate(ctx, time.Now())
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) Evaluate(ctx context.Context, now time.Time) {
|
||||
if s == nil || s.opsService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := s.opsService.ListAlertRules(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] failed to list rules: %v", err)
|
||||
return
|
||||
}
|
||||
if len(rules) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxSustainedByWindow := make(map[int]int)
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
window := rule.WindowMinutes
|
||||
if window <= 0 {
|
||||
window = 1
|
||||
}
|
||||
sustained := rule.SustainedMinutes
|
||||
if sustained <= 0 {
|
||||
sustained = 1
|
||||
}
|
||||
if sustained > maxSustainedByWindow[window] {
|
||||
maxSustainedByWindow[window] = sustained
|
||||
}
|
||||
}
|
||||
|
||||
metricsByWindow := make(map[int][]OpsMetrics)
|
||||
for window, limit := range maxSustainedByWindow {
|
||||
metrics, err := s.opsService.ListRecentSystemMetrics(ctx, window, limit)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] failed to load metrics window=%dm: %v", window, err)
|
||||
continue
|
||||
}
|
||||
metricsByWindow[window] = metrics
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
window := rule.WindowMinutes
|
||||
if window <= 0 {
|
||||
window = 1
|
||||
}
|
||||
sustained := rule.SustainedMinutes
|
||||
if sustained <= 0 {
|
||||
sustained = 1
|
||||
}
|
||||
|
||||
metrics := metricsByWindow[window]
|
||||
selected, ok := selectContiguousMetrics(metrics, sustained, now)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
breached, latestValue, ok := evaluateRule(rule, selected)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
activeEvent, err := s.opsService.GetActiveAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] failed to get active event (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if breached {
|
||||
if activeEvent != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastEvent, err := s.opsService.GetLatestAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] failed to get latest event (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
if lastEvent != nil && rule.CooldownMinutes > 0 {
|
||||
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
|
||||
if now.Sub(lastEvent.FiredAt) < cooldown {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
event := &OpsAlertEvent{
|
||||
RuleID: rule.ID,
|
||||
Severity: rule.Severity,
|
||||
Status: OpsAlertStatusFiring,
|
||||
Title: fmt.Sprintf("%s: %s", rule.Severity, rule.Name),
|
||||
Description: buildAlertDescription(rule, latestValue),
|
||||
MetricValue: latestValue,
|
||||
ThresholdValue: rule.Threshold,
|
||||
FiredAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if err := s.opsService.CreateAlertEvent(ctx, event); err != nil {
|
||||
log.Printf("[OpsAlert] failed to create event (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
emailSent, webhookSent := s.dispatchNotifications(ctx, rule, event)
|
||||
if emailSent || webhookSent {
|
||||
if err := s.opsService.UpdateAlertEventNotifications(ctx, event.ID, emailSent, webhookSent); err != nil {
|
||||
log.Printf("[OpsAlert] failed to update notification flags (event=%d): %v", event.ID, err)
|
||||
}
|
||||
}
|
||||
} else if activeEvent != nil {
|
||||
resolvedAt := now
|
||||
if err := s.opsService.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
|
||||
log.Printf("[OpsAlert] failed to resolve event (event=%d): %v", activeEvent.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const opsMetricsContinuityTolerance = 20 * time.Second
|
||||
|
||||
// selectContiguousMetrics picks the newest N metrics and verifies they are continuous.
|
||||
//
|
||||
// This prevents a sustained rule from triggering when metrics sampling has gaps
|
||||
// (e.g. collector downtime) and avoids evaluating "stale" data.
|
||||
//
|
||||
// Assumptions:
|
||||
// - Metrics are ordered by UpdatedAt DESC (newest first).
|
||||
// - Metrics are expected to be collected at opsMetricsInterval cadence.
|
||||
func selectContiguousMetrics(metrics []OpsMetrics, needed int, now time.Time) ([]OpsMetrics, bool) {
|
||||
if needed <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
if len(metrics) < needed {
|
||||
return nil, false
|
||||
}
|
||||
newest := metrics[0].UpdatedAt
|
||||
if newest.IsZero() {
|
||||
return nil, false
|
||||
}
|
||||
if now.Sub(newest) > opsMetricsInterval+opsMetricsContinuityTolerance {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
selected := metrics[:needed]
|
||||
for i := 0; i < len(selected)-1; i++ {
|
||||
a := selected[i].UpdatedAt
|
||||
b := selected[i+1].UpdatedAt
|
||||
if a.IsZero() || b.IsZero() {
|
||||
return nil, false
|
||||
}
|
||||
gap := a.Sub(b)
|
||||
if gap < opsMetricsInterval-opsMetricsContinuityTolerance || gap > opsMetricsInterval+opsMetricsContinuityTolerance {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return selected, true
|
||||
}
|
||||
|
||||
func evaluateRule(rule OpsAlertRule, metrics []OpsMetrics) (bool, float64, bool) {
|
||||
if len(metrics) == 0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
latestValue, ok := metricValue(metrics[0], rule.MetricType)
|
||||
if !ok {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
for _, metric := range metrics {
|
||||
value, ok := metricValue(metric, rule.MetricType)
|
||||
if !ok || !compareMetric(value, rule.Operator, rule.Threshold) {
|
||||
return false, latestValue, true
|
||||
}
|
||||
}
|
||||
|
||||
return true, latestValue, true
|
||||
}
|
||||
|
||||
func metricValue(metric OpsMetrics, metricType string) (float64, bool) {
|
||||
switch metricType {
|
||||
case OpsMetricSuccessRate:
|
||||
if metric.RequestCount == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return metric.SuccessRate, true
|
||||
case OpsMetricErrorRate:
|
||||
if metric.RequestCount == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return metric.ErrorRate, true
|
||||
case OpsMetricP95LatencyMs:
|
||||
return float64(metric.P95LatencyMs), true
|
||||
case OpsMetricP99LatencyMs:
|
||||
return float64(metric.P99LatencyMs), true
|
||||
case OpsMetricHTTP2Errors:
|
||||
return float64(metric.HTTP2Errors), true
|
||||
case OpsMetricCPUUsagePercent:
|
||||
return metric.CPUUsagePercent, true
|
||||
case OpsMetricMemoryUsagePercent:
|
||||
return metric.MemoryUsagePercent, true
|
||||
case OpsMetricQueueDepth:
|
||||
return float64(metric.ConcurrencyQueueDepth), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func compareMetric(value float64, operator string, threshold float64) bool {
|
||||
switch operator {
|
||||
case ">":
|
||||
return value > threshold
|
||||
case ">=":
|
||||
return value >= threshold
|
||||
case "<":
|
||||
return value < threshold
|
||||
case "<=":
|
||||
return value <= threshold
|
||||
case "==":
|
||||
return value == threshold
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func buildAlertDescription(rule OpsAlertRule, value float64) string {
|
||||
window := rule.WindowMinutes
|
||||
if window <= 0 {
|
||||
window = 1
|
||||
}
|
||||
return fmt.Sprintf("Rule %s triggered: %s %s %.2f (current %.2f) over last %dm",
|
||||
rule.Name,
|
||||
rule.MetricType,
|
||||
rule.Operator,
|
||||
rule.Threshold,
|
||||
value,
|
||||
window,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) dispatchNotifications(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) (bool, bool) {
|
||||
emailSent := false
|
||||
webhookSent := false
|
||||
|
||||
notifyCtx, cancel := s.notificationContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if rule.NotifyEmail {
|
||||
emailSent = s.sendEmailNotification(notifyCtx, rule, event)
|
||||
}
|
||||
if rule.NotifyWebhook && rule.WebhookURL != "" {
|
||||
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
|
||||
}
|
||||
// Fallback channel: if email is enabled but ultimately fails, try webhook even if the
|
||||
// webhook toggle is off (as long as a webhook URL is configured).
|
||||
if rule.NotifyEmail && !emailSent && !rule.NotifyWebhook && rule.WebhookURL != "" {
|
||||
log.Printf("[OpsAlert] email failed; attempting webhook fallback (rule=%d)", rule.ID)
|
||||
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
|
||||
}
|
||||
|
||||
return emailSent, webhookSent
|
||||
}
|
||||
|
||||
const (
|
||||
opsAlertEvaluateTimeout = 45 * time.Second
|
||||
opsAlertNotificationTimeout = 30 * time.Second
|
||||
opsAlertEmailMaxRetries = 3
|
||||
)
|
||||
|
||||
var opsAlertEmailBackoff = []time.Duration{
|
||||
1 * time.Second,
|
||||
2 * time.Second,
|
||||
4 * time.Second,
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) notificationContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
parent := ctx
|
||||
if s != nil && s.stopCtx != nil {
|
||||
parent = s.stopCtx
|
||||
}
|
||||
if parent == nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
return context.WithTimeout(parent, opsAlertNotificationTimeout)
|
||||
}
|
||||
|
||||
var opsAlertSleep = sleepWithContext
|
||||
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
if d <= 0 {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
time.Sleep(d)
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func retryWithBackoff(
|
||||
ctx context.Context,
|
||||
maxRetries int,
|
||||
backoff []time.Duration,
|
||||
fn func() error,
|
||||
onError func(attempt int, total int, nextDelay time.Duration, err error),
|
||||
) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if maxRetries < 0 {
|
||||
maxRetries = 0
|
||||
}
|
||||
totalAttempts := maxRetries + 1
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= totalAttempts; attempt++ {
|
||||
if attempt > 1 {
|
||||
backoffIdx := attempt - 2
|
||||
if backoffIdx < len(backoff) {
|
||||
if err := opsAlertSleep(ctx, backoff[backoffIdx]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fn(); err != nil {
|
||||
lastErr = err
|
||||
nextDelay := time.Duration(0)
|
||||
if attempt < totalAttempts {
|
||||
nextIdx := attempt - 1
|
||||
if nextIdx < len(backoff) {
|
||||
nextDelay = backoff[nextIdx]
|
||||
}
|
||||
}
|
||||
if onError != nil {
|
||||
onError(attempt, totalAttempts, nextDelay, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) sendEmailNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
|
||||
if s.emailService == nil || s.userService == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
admin, err := s.userService.GetFirstAdmin(ctx)
|
||||
if err != nil || admin == nil || admin.Email == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[Ops Alert][%s] %s", rule.Severity, rule.Name)
|
||||
body := fmt.Sprintf(
|
||||
"Alert triggered: %s\n\nMetric: %s\nThreshold: %.2f\nCurrent: %.2f\nWindow: %dm\nStatus: %s\nTime: %s",
|
||||
rule.Name,
|
||||
rule.MetricType,
|
||||
rule.Threshold,
|
||||
event.MetricValue,
|
||||
rule.WindowMinutes,
|
||||
event.Status,
|
||||
event.FiredAt.Format(time.RFC3339),
|
||||
)
|
||||
|
||||
config, err := s.emailService.GetSMTPConfig(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] email config load failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := retryWithBackoff(
|
||||
ctx,
|
||||
opsAlertEmailMaxRetries,
|
||||
opsAlertEmailBackoff,
|
||||
func() error {
|
||||
return s.emailService.SendEmailWithConfig(config, admin.Email, subject, body)
|
||||
},
|
||||
func(attempt int, total int, nextDelay time.Duration, err error) {
|
||||
if attempt < total {
|
||||
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), retrying in %s: %v", attempt, total, nextDelay, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), giving up: %v", attempt, total, err)
|
||||
},
|
||||
); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf("[OpsAlert] email send canceled: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OpsAlertService) sendWebhookNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
webhookTarget, err := validateWebhookURL(ctx, rule.WebhookURL)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] invalid webhook url (rule=%d): %v", rule.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"rule_id": rule.ID,
|
||||
"rule_name": rule.Name,
|
||||
"severity": rule.Severity,
|
||||
"status": event.Status,
|
||||
"metric_type": rule.MetricType,
|
||||
"metric_value": event.MetricValue,
|
||||
"threshold_value": rule.Threshold,
|
||||
"window_minutes": rule.WindowMinutes,
|
||||
"fired_at": event.FiredAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookTarget.URL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := buildWebhookHTTPClient(s.httpClient, webhookTarget).Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlert] webhook send failed: %v", err)
|
||||
return false
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Printf("[OpsAlert] webhook returned status %d", resp.StatusCode)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
const webhookHTTPClientTimeout = 10 * time.Second
|
||||
|
||||
func buildWebhookHTTPClient(base *http.Client, webhookTarget *validatedWebhookTarget) *http.Client {
|
||||
var client http.Client
|
||||
if base != nil {
|
||||
client = *base
|
||||
}
|
||||
if client.Timeout <= 0 {
|
||||
client.Timeout = webhookHTTPClientTimeout
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
if webhookTarget != nil {
|
||||
client.Transport = buildWebhookTransport(client.Transport, webhookTarget)
|
||||
}
|
||||
return &client
|
||||
}
|
||||
|
||||
var disallowedWebhookIPNets = []net.IPNet{
|
||||
// "this host on this network" / unspecified.
|
||||
mustParseCIDR("0.0.0.0/8"),
|
||||
mustParseCIDR("127.0.0.0/8"), // loopback (includes 127.0.0.1)
|
||||
mustParseCIDR("10.0.0.0/8"), // RFC1918
|
||||
mustParseCIDR("192.168.0.0/16"), // RFC1918
|
||||
mustParseCIDR("172.16.0.0/12"), // RFC1918 (172.16.0.0 - 172.31.255.255)
|
||||
mustParseCIDR("100.64.0.0/10"), // RFC6598 (carrier-grade NAT)
|
||||
mustParseCIDR("169.254.0.0/16"), // IPv4 link-local (includes 169.254.169.254 metadata IP on many clouds)
|
||||
mustParseCIDR("198.18.0.0/15"), // RFC2544 benchmark testing
|
||||
mustParseCIDR("224.0.0.0/4"), // IPv4 multicast
|
||||
mustParseCIDR("240.0.0.0/4"), // IPv4 reserved
|
||||
mustParseCIDR("::/128"), // IPv6 unspecified
|
||||
mustParseCIDR("::1/128"), // IPv6 loopback
|
||||
mustParseCIDR("fc00::/7"), // IPv6 unique local
|
||||
mustParseCIDR("fe80::/10"), // IPv6 link-local
|
||||
mustParseCIDR("ff00::/8"), // IPv6 multicast
|
||||
}
|
||||
|
||||
func mustParseCIDR(cidr string) net.IPNet {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return *block
|
||||
}
|
||||
|
||||
var lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
return net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
}
|
||||
|
||||
type validatedWebhookTarget struct {
|
||||
URL *url.URL
|
||||
|
||||
host string
|
||||
port string
|
||||
pinnedIPs []net.IP
|
||||
}
|
||||
|
||||
var webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
func buildWebhookTransport(base http.RoundTripper, webhookTarget *validatedWebhookTarget) http.RoundTripper {
|
||||
if webhookTarget == nil || webhookTarget.URL == nil {
|
||||
return base
|
||||
}
|
||||
|
||||
var transport *http.Transport
|
||||
switch typed := base.(type) {
|
||||
case *http.Transport:
|
||||
if typed != nil {
|
||||
transport = typed.Clone()
|
||||
}
|
||||
}
|
||||
if transport == nil {
|
||||
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok && defaultTransport != nil {
|
||||
transport = defaultTransport.Clone()
|
||||
} else {
|
||||
transport = (&http.Transport{}).Clone()
|
||||
}
|
||||
}
|
||||
|
||||
webhookHost := webhookTarget.host
|
||||
webhookPort := webhookTarget.port
|
||||
pinnedIPs := append([]net.IP(nil), webhookTarget.pinnedIPs...)
|
||||
|
||||
transport.Proxy = nil
|
||||
transport.DialTLSContext = nil
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil || host == "" || port == "" {
|
||||
return nil, fmt.Errorf("webhook dial target is invalid: %q", addr)
|
||||
}
|
||||
|
||||
canonicalHost := strings.TrimSuffix(strings.ToLower(host), ".")
|
||||
if canonicalHost != webhookHost || port != webhookPort {
|
||||
return nil, fmt.Errorf("webhook dial target mismatch: %q", addr)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, ip := range pinnedIPs {
|
||||
if isDisallowedWebhookIP(ip) {
|
||||
lastErr = fmt.Errorf("webhook target resolves to a disallowed ip")
|
||||
continue
|
||||
}
|
||||
|
||||
dialAddr := net.JoinHostPort(ip.String(), port)
|
||||
conn, err := webhookBaseDialContext(ctx, network, dialAddr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("webhook target has no resolved addresses")
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
func validateWebhookURL(ctx context.Context, raw string) (*validatedWebhookTarget, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, errors.New("webhook url is empty")
|
||||
}
|
||||
// Avoid request smuggling / header injection vectors.
|
||||
if strings.ContainsAny(raw, "\r\n") {
|
||||
return nil, errors.New("webhook url contains invalid characters")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, errors.New("webhook url format is invalid")
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return nil, errors.New("webhook url scheme must be https")
|
||||
}
|
||||
parsed.Scheme = "https"
|
||||
if parsed.Host == "" || parsed.Hostname() == "" {
|
||||
return nil, errors.New("webhook url must include host")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return nil, errors.New("webhook url must not include userinfo")
|
||||
}
|
||||
if parsed.Port() != "" {
|
||||
port, err := strconv.Atoi(parsed.Port())
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return nil, errors.New("webhook url port is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
|
||||
if host == "localhost" {
|
||||
return nil, errors.New("webhook url host must not be localhost")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isDisallowedWebhookIP(ip) {
|
||||
return nil, errors.New("webhook url host resolves to a disallowed ip")
|
||||
}
|
||||
return &validatedWebhookTarget{
|
||||
URL: parsed,
|
||||
host: host,
|
||||
port: portForScheme(parsed),
|
||||
pinnedIPs: []net.IP{ip},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ips, err := lookupIPAddrs(ctx, host)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, errors.New("webhook url host cannot be resolved")
|
||||
}
|
||||
pinned := make([]net.IP, 0, len(ips))
|
||||
for _, addr := range ips {
|
||||
if isDisallowedWebhookIP(addr.IP) {
|
||||
return nil, errors.New("webhook url host resolves to a disallowed ip")
|
||||
}
|
||||
if addr.IP != nil {
|
||||
pinned = append(pinned, addr.IP)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pinned) == 0 {
|
||||
return nil, errors.New("webhook url host cannot be resolved")
|
||||
}
|
||||
|
||||
return &validatedWebhookTarget{
|
||||
URL: parsed,
|
||||
host: host,
|
||||
port: portForScheme(parsed),
|
||||
pinnedIPs: uniqueResolvedIPs(pinned),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isDisallowedWebhookIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
ip = ip4
|
||||
} else if ip16 := ip.To16(); ip16 != nil {
|
||||
ip = ip16
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Disallow non-public addresses even if they're not explicitly covered by the CIDR list.
|
||||
// This provides defense-in-depth against SSRF targets such as link-local, multicast, and
|
||||
// unspecified addresses, and ensures any "pinned" IP is still blocked at dial time.
|
||||
if ip.IsUnspecified() ||
|
||||
ip.IsLoopback() ||
|
||||
ip.IsMulticast() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, block := range disallowedWebhookIPNets {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portForScheme(u *url.URL) string {
|
||||
if u != nil && u.Port() != "" {
|
||||
return u.Port()
|
||||
}
|
||||
return "443"
|
||||
}
|
||||
|
||||
func uniqueResolvedIPs(ips []net.IP) []net.IP {
|
||||
seen := make(map[string]struct{}, len(ips))
|
||||
out := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
key := ip.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, ip)
|
||||
}
|
||||
return out
|
||||
}
|
||||
271
backend/internal/service/ops_alert_service_integration_test.go
Normal file
271
backend/internal/service/ops_alert_service_integration_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
//go:build integration
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// This integration test protects the DI startup contract for OpsAlertService.
|
||||
//
|
||||
// Background:
|
||||
// - OpsMetricsCollector previously called alertService.Start()/Evaluate() directly.
|
||||
// - Those direct calls were removed, so OpsAlertService must now start via DI
|
||||
// (ProvideOpsAlertService in wire.go) and run its own evaluation ticker.
|
||||
//
|
||||
// What we validate here:
|
||||
// 1. When we construct via the Wire provider functions (ProvideOpsAlertService +
|
||||
// ProvideOpsMetricsCollector), OpsAlertService starts automatically.
|
||||
// 2. Its evaluation loop continues to tick even if OpsMetricsCollector is stopped,
|
||||
// proving the alert evaluator is independent.
|
||||
// 3. The evaluation path can trigger alert logic (CreateAlertEvent called).
|
||||
func TestOpsAlertService_StartedViaWireProviders_RunsIndependentTicker(t *testing.T) {
|
||||
oldInterval := opsAlertEvalInterval
|
||||
opsAlertEvalInterval = 25 * time.Millisecond
|
||||
t.Cleanup(func() { opsAlertEvalInterval = oldInterval })
|
||||
|
||||
repo := newFakeOpsRepository()
|
||||
opsService := NewOpsService(repo, nil)
|
||||
|
||||
// Start via the Wire provider function (the production DI path).
|
||||
alertService := ProvideOpsAlertService(opsService, nil, nil)
|
||||
t.Cleanup(alertService.Stop)
|
||||
|
||||
// Construct via ProvideOpsMetricsCollector (wire.go). Stop immediately to ensure
|
||||
// the alert ticker keeps running without the metrics collector.
|
||||
collector := ProvideOpsMetricsCollector(opsService, NewConcurrencyService(nil))
|
||||
collector.Stop()
|
||||
|
||||
// Wait for at least one evaluation (run() calls evaluateOnce immediately).
|
||||
require.Eventually(t, func() bool {
|
||||
return repo.listRulesCalls.Load() >= 1
|
||||
}, 1*time.Second, 5*time.Millisecond)
|
||||
|
||||
// Confirm the evaluation loop keeps ticking after the metrics collector is stopped.
|
||||
callsAfterCollectorStop := repo.listRulesCalls.Load()
|
||||
require.Eventually(t, func() bool {
|
||||
return repo.listRulesCalls.Load() >= callsAfterCollectorStop+2
|
||||
}, 1*time.Second, 5*time.Millisecond)
|
||||
|
||||
// Confirm the evaluation logic actually fires an alert event at least once.
|
||||
select {
|
||||
case <-repo.eventCreatedCh:
|
||||
// ok
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("expected OpsAlertService to create an alert event, but none was created (ListAlertRules calls=%d)", repo.listRulesCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeOpsRepository() *fakeOpsRepository {
|
||||
return &fakeOpsRepository{
|
||||
eventCreatedCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// fakeOpsRepository is a lightweight in-memory stub of OpsRepository for integration tests.
|
||||
// It avoids real DB/Redis usage and provides deterministic responses fast.
|
||||
type fakeOpsRepository struct {
|
||||
listRulesCalls atomic.Int64
|
||||
|
||||
mu sync.Mutex
|
||||
activeEvent *OpsAlertEvent
|
||||
latestEvent *OpsAlertEvent
|
||||
nextEventID int64
|
||||
eventCreatedCh chan struct{}
|
||||
eventOnce sync.Once
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) CreateErrorLog(ctx context.Context, log *OpsErrorLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) ListErrorLogsLegacy(ctx context.Context, filters OpsErrorLogFilters) ([]OpsErrorLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) ListErrorLogs(ctx context.Context, filter *ErrorLogFilter) ([]*ErrorLog, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
|
||||
return &OpsMetrics{WindowMinutes: 1}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) CreateSystemMetric(ctx context.Context, metric *OpsMetrics) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
|
||||
return &OpsWindowStats{}, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetProviderStats(ctx context.Context, startTime, endTime time.Time) ([]*ProviderStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetLatencyHistogram(ctx context.Context, startTime, endTime time.Time) ([]*LatencyHistogramItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetErrorDistribution(ctx context.Context, startTime, endTime time.Time) ([]*ErrorDistributionItem, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
|
||||
if limit <= 0 {
|
||||
limit = 1
|
||||
}
|
||||
now := time.Now()
|
||||
metrics := make([]OpsMetrics, 0, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
metrics = append(metrics, OpsMetrics{
|
||||
WindowMinutes: windowMinutes,
|
||||
CPUUsagePercent: 99,
|
||||
UpdatedAt: now.Add(-time.Duration(i) * opsMetricsInterval),
|
||||
})
|
||||
}
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) ListSystemMetricsRange(ctx context.Context, windowMinutes int, startTime, endTime time.Time, limit int) ([]OpsMetrics, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
|
||||
call := r.listRulesCalls.Add(1)
|
||||
// Delay enabling rules slightly so the test can stop OpsMetricsCollector first,
|
||||
// then observe the alert evaluator ticking independently.
|
||||
if call < 5 {
|
||||
return nil, nil
|
||||
}
|
||||
return []OpsAlertRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "cpu too high (test)",
|
||||
Enabled: true,
|
||||
MetricType: OpsMetricCPUUsagePercent,
|
||||
Operator: ">",
|
||||
Threshold: 0,
|
||||
WindowMinutes: 1,
|
||||
SustainedMinutes: 1,
|
||||
Severity: "P1",
|
||||
NotifyEmail: false,
|
||||
NotifyWebhook: false,
|
||||
CooldownMinutes: 0,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.activeEvent == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if r.activeEvent.RuleID != ruleID {
|
||||
return nil, nil
|
||||
}
|
||||
if r.activeEvent.Status != OpsAlertStatusFiring {
|
||||
return nil, nil
|
||||
}
|
||||
clone := *r.activeEvent
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.latestEvent == nil || r.latestEvent.RuleID != ruleID {
|
||||
return nil, nil
|
||||
}
|
||||
clone := *r.latestEvent
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
|
||||
if event == nil {
|
||||
return nil
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.nextEventID++
|
||||
event.ID = r.nextEventID
|
||||
|
||||
clone := *event
|
||||
r.latestEvent = &clone
|
||||
if clone.Status == OpsAlertStatusFiring {
|
||||
r.activeEvent = &clone
|
||||
}
|
||||
|
||||
r.eventOnce.Do(func() { close(r.eventCreatedCh) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.activeEvent != nil && r.activeEvent.ID == eventID {
|
||||
r.activeEvent.Status = status
|
||||
r.activeEvent.ResolvedAt = resolvedAt
|
||||
}
|
||||
if r.latestEvent != nil && r.latestEvent.ID == eventID {
|
||||
r.latestEvent.Status = status
|
||||
r.latestEvent.ResolvedAt = resolvedAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.activeEvent != nil && r.activeEvent.ID == eventID {
|
||||
r.activeEvent.EmailSent = emailSent
|
||||
r.activeEvent.WebhookSent = webhookSent
|
||||
}
|
||||
if r.latestEvent != nil && r.latestEvent.ID == eventID {
|
||||
r.latestEvent.EmailSent = emailSent
|
||||
r.latestEvent.WebhookSent = webhookSent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) CountActiveAlerts(ctx context.Context) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.activeEvent == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetOverviewStats(ctx context.Context, startTime, endTime time.Time) (*OverviewStats, error) {
|
||||
return &OverviewStats{}, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *OpsMetrics) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*DashboardOverviewData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *DashboardOverviewData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeOpsRepository) PingRedis(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
315
backend/internal/service/ops_alert_service_test.go
Normal file
315
backend/internal/service/ops_alert_service_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
//go:build unit || opsalert_unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSelectContiguousMetrics_Contiguous(t *testing.T) {
|
||||
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
metrics := []OpsMetrics{
|
||||
{UpdatedAt: now},
|
||||
{UpdatedAt: now.Add(-1 * time.Minute)},
|
||||
{UpdatedAt: now.Add(-2 * time.Minute)},
|
||||
}
|
||||
|
||||
selected, ok := selectContiguousMetrics(metrics, 3, now)
|
||||
require.True(t, ok)
|
||||
require.Len(t, selected, 3)
|
||||
}
|
||||
|
||||
func TestSelectContiguousMetrics_GapFails(t *testing.T) {
|
||||
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
metrics := []OpsMetrics{
|
||||
{UpdatedAt: now},
|
||||
// Missing the -1m sample (gap ~=2m).
|
||||
{UpdatedAt: now.Add(-2 * time.Minute)},
|
||||
{UpdatedAt: now.Add(-3 * time.Minute)},
|
||||
}
|
||||
|
||||
_, ok := selectContiguousMetrics(metrics, 3, now)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSelectContiguousMetrics_StaleNewestFails(t *testing.T) {
|
||||
now := time.Date(2026, 1, 1, 0, 10, 0, 0, time.UTC)
|
||||
metrics := []OpsMetrics{
|
||||
{UpdatedAt: now.Add(-10 * time.Minute)},
|
||||
{UpdatedAt: now.Add(-11 * time.Minute)},
|
||||
}
|
||||
|
||||
_, ok := selectContiguousMetrics(metrics, 2, now)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMetricValue_SuccessRate_NoTrafficIsNoData(t *testing.T) {
|
||||
metric := OpsMetrics{
|
||||
RequestCount: 0,
|
||||
SuccessRate: 0,
|
||||
}
|
||||
value, ok := metricValue(metric, OpsMetricSuccessRate)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0.0, value)
|
||||
}
|
||||
|
||||
func TestOpsAlertService_StopWithoutStart_NoPanic(t *testing.T) {
|
||||
s := NewOpsAlertService(nil, nil, nil)
|
||||
require.NotPanics(t, func() { s.Stop() })
|
||||
}
|
||||
|
||||
func TestOpsAlertService_StartStop_Graceful(t *testing.T) {
|
||||
s := NewOpsAlertService(nil, nil, nil)
|
||||
s.interval = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.StartWithContext(ctx)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.Stop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// ok
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Stop did not return; background goroutine likely stuck")
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() { s.Stop() })
|
||||
}
|
||||
|
||||
func TestBuildWebhookHTTPClient_DefaultTimeout(t *testing.T) {
|
||||
client := buildWebhookHTTPClient(nil, nil)
|
||||
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
|
||||
require.NotNil(t, client.CheckRedirect)
|
||||
require.ErrorIs(t, client.CheckRedirect(nil, nil), http.ErrUseLastResponse)
|
||||
|
||||
base := &http.Client{}
|
||||
client = buildWebhookHTTPClient(base, nil)
|
||||
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
|
||||
require.NotNil(t, client.CheckRedirect)
|
||||
|
||||
base = &http.Client{Timeout: 2 * time.Second}
|
||||
client = buildWebhookHTTPClient(base, nil)
|
||||
require.Equal(t, 2*time.Second, client.Timeout)
|
||||
require.NotNil(t, client.CheckRedirect)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RequiresHTTPS(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
|
||||
}
|
||||
|
||||
_, err := validateWebhookURL(context.Background(), "http://example.com/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_InvalidFormatRejected(t *testing.T) {
|
||||
_, err := validateWebhookURL(context.Background(), "https://[::1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsUserinfo(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
|
||||
}
|
||||
|
||||
_, err := validateWebhookURL(context.Background(), "https://user:pass@example.com/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsLocalhost(t *testing.T) {
|
||||
_, err := validateWebhookURL(context.Background(), "https://localhost/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsPrivateIPLiteral(t *testing.T) {
|
||||
cases := []string{
|
||||
"https://0.0.0.0/webhook",
|
||||
"https://127.0.0.1/webhook",
|
||||
"https://10.0.0.1/webhook",
|
||||
"https://192.168.1.2/webhook",
|
||||
"https://172.16.0.1/webhook",
|
||||
"https://172.31.255.255/webhook",
|
||||
"https://100.64.0.1/webhook",
|
||||
"https://169.254.169.254/webhook",
|
||||
"https://198.18.0.1/webhook",
|
||||
"https://224.0.0.1/webhook",
|
||||
"https://240.0.0.1/webhook",
|
||||
"https://[::]/webhook",
|
||||
"https://[::1]/webhook",
|
||||
"https://[ff02::1]/webhook",
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
_, err := validateWebhookURL(context.Background(), tc)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsPrivateIPViaDNS(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
require.Equal(t, "internal.example", host)
|
||||
return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
|
||||
}
|
||||
|
||||
_, err := validateWebhookURL(context.Background(), "https://internal.example/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsLinkLocalIPViaDNS(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
require.Equal(t, "metadata.example", host)
|
||||
return []net.IPAddr{{IP: net.ParseIP("169.254.169.254")}}, nil
|
||||
}
|
||||
|
||||
_, err := validateWebhookURL(context.Background(), "https://metadata.example/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_AllowsPublicHostViaDNS(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
require.Equal(t, "example.com", host)
|
||||
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
|
||||
}
|
||||
|
||||
target, err := validateWebhookURL(context.Background(), "https://example.com:443/webhook")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https", target.URL.Scheme)
|
||||
require.Equal(t, "example.com", target.URL.Hostname())
|
||||
require.Equal(t, "443", target.URL.Port())
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL_RejectsInvalidPort(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
t.Cleanup(func() { lookupIPAddrs = oldLookup })
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
|
||||
}
|
||||
|
||||
_, err := validateWebhookURL(context.Background(), "https://example.com:99999/webhook")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWebhookTransport_UsesPinnedIP_NoDNSRebinding(t *testing.T) {
|
||||
oldLookup := lookupIPAddrs
|
||||
oldDial := webhookBaseDialContext
|
||||
t.Cleanup(func() {
|
||||
lookupIPAddrs = oldLookup
|
||||
webhookBaseDialContext = oldDial
|
||||
})
|
||||
|
||||
lookupCalls := 0
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
lookupCalls++
|
||||
require.Equal(t, "example.com", host)
|
||||
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
|
||||
}
|
||||
|
||||
target, err := validateWebhookURL(context.Background(), "https://example.com/webhook")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, lookupCalls)
|
||||
|
||||
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
lookupCalls++
|
||||
return []net.IPAddr{{IP: net.ParseIP("10.0.0.1")}}, nil
|
||||
}
|
||||
|
||||
var dialAddrs []string
|
||||
webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialAddrs = append(dialAddrs, addr)
|
||||
return nil, errors.New("dial blocked in test")
|
||||
}
|
||||
|
||||
client := buildWebhookHTTPClient(nil, target)
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
|
||||
_, err = transport.DialContext(context.Background(), "tcp", "example.com:443")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, []string{"93.184.216.34:443"}, dialAddrs)
|
||||
require.Equal(t, 1, lookupCalls, "dial path must not re-resolve DNS")
|
||||
}
|
||||
|
||||
func TestRetryWithBackoff_SucceedsAfterRetries(t *testing.T) {
|
||||
oldSleep := opsAlertSleep
|
||||
t.Cleanup(func() { opsAlertSleep = oldSleep })
|
||||
|
||||
var slept []time.Duration
|
||||
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
|
||||
slept = append(slept, d)
|
||||
return nil
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
err := retryWithBackoff(
|
||||
context.Background(),
|
||||
3,
|
||||
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
|
||||
func() error {
|
||||
attempts++
|
||||
if attempts <= 3 {
|
||||
return errors.New("send failed")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, attempts)
|
||||
require.Equal(t, []time.Duration{time.Second, 2 * time.Second, 4 * time.Second}, slept)
|
||||
}
|
||||
|
||||
func TestRetryWithBackoff_ContextCanceledStopsRetries(t *testing.T) {
|
||||
oldSleep := opsAlertSleep
|
||||
t.Cleanup(func() { opsAlertSleep = oldSleep })
|
||||
|
||||
var slept []time.Duration
|
||||
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
|
||||
slept = append(slept, d)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
attempts := 0
|
||||
err := retryWithBackoff(
|
||||
ctx,
|
||||
3,
|
||||
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
|
||||
func() error {
|
||||
attempts++
|
||||
return errors.New("send failed")
|
||||
},
|
||||
func(attempt int, total int, nextDelay time.Duration, err error) {
|
||||
if attempt == 1 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.Equal(t, 1, attempts)
|
||||
require.Equal(t, []time.Duration{time.Second}, slept)
|
||||
}
|
||||
92
backend/internal/service/ops_alerts.go
Normal file
92
backend/internal/service/ops_alerts.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
OpsAlertStatusFiring = "firing"
|
||||
OpsAlertStatusResolved = "resolved"
|
||||
)
|
||||
|
||||
const (
|
||||
OpsMetricSuccessRate = "success_rate"
|
||||
OpsMetricErrorRate = "error_rate"
|
||||
OpsMetricP95LatencyMs = "p95_latency_ms"
|
||||
OpsMetricP99LatencyMs = "p99_latency_ms"
|
||||
OpsMetricHTTP2Errors = "http2_errors"
|
||||
OpsMetricCPUUsagePercent = "cpu_usage_percent"
|
||||
OpsMetricMemoryUsagePercent = "memory_usage_percent"
|
||||
OpsMetricQueueDepth = "concurrency_queue_depth"
|
||||
)
|
||||
|
||||
type OpsAlertRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Enabled bool `json:"enabled"`
|
||||
MetricType string `json:"metric_type"`
|
||||
Operator string `json:"operator"`
|
||||
Threshold float64 `json:"threshold"`
|
||||
WindowMinutes int `json:"window_minutes"`
|
||||
SustainedMinutes int `json:"sustained_minutes"`
|
||||
Severity string `json:"severity"`
|
||||
NotifyEmail bool `json:"notify_email"`
|
||||
NotifyWebhook bool `json:"notify_webhook"`
|
||||
WebhookURL string `json:"webhook_url"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
DimensionFilters map[string]any `json:"dimension_filters,omitempty"`
|
||||
NotifyChannels []string `json:"notify_channels,omitempty"`
|
||||
NotifyConfig map[string]any `json:"notify_config,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEvent struct {
|
||||
ID int64 `json:"id"`
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
MetricValue float64 `json:"metric_value"`
|
||||
ThresholdValue float64 `json:"threshold_value"`
|
||||
FiredAt time.Time `json:"fired_at"`
|
||||
ResolvedAt *time.Time `json:"resolved_at"`
|
||||
EmailSent bool `json:"email_sent"`
|
||||
WebhookSent bool `json:"webhook_sent"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (s *OpsService) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
|
||||
return s.repo.ListAlertRules(ctx)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
return s.repo.GetActiveAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
return s.repo.GetLatestAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
|
||||
return s.repo.CreateAlertEvent(ctx, event)
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
return s.repo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
|
||||
return s.repo.UpdateAlertEventNotifications(ctx, eventID, emailSent, webhookSent)
|
||||
}
|
||||
|
||||
func (s *OpsService) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
|
||||
return s.repo.ListRecentSystemMetrics(ctx, windowMinutes, limit)
|
||||
}
|
||||
|
||||
func (s *OpsService) CountActiveAlerts(ctx context.Context) (int, error) {
|
||||
return s.repo.CountActiveAlerts(ctx)
|
||||
}
|
||||
203
backend/internal/service/ops_metrics_collector.go
Normal file
203
backend/internal/service/ops_metrics_collector.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
)
|
||||
|
||||
const (
|
||||
opsMetricsInterval = 1 * time.Minute
|
||||
opsMetricsCollectTimeout = 10 * time.Second
|
||||
|
||||
opsMetricsWindowShortMinutes = 1
|
||||
opsMetricsWindowLongMinutes = 5
|
||||
|
||||
bytesPerMB = 1024 * 1024
|
||||
cpuUsageSampleInterval = 0 * time.Second
|
||||
|
||||
percentScale = 100
|
||||
)
|
||||
|
||||
type OpsMetricsCollector struct {
|
||||
opsService *OpsService
|
||||
concurrencyService *ConcurrencyService
|
||||
interval time.Duration
|
||||
lastGCPauseTotal uint64
|
||||
lastGCPauseMu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
|
||||
return &OpsMetricsCollector{
|
||||
opsService: opsService,
|
||||
concurrencyService: concurrencyService,
|
||||
interval: opsMetricsInterval,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Start() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.startOnce.Do(func() {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
}
|
||||
go c.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Stop() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.stopOnce.Do(func() {
|
||||
if c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) run() {
|
||||
ticker := time.NewTicker(c.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
c.collectOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.collectOnce()
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectOnce() {
|
||||
if c.opsService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectTimeout)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now()
|
||||
systemStats := c.collectSystemStats(ctx)
|
||||
queueDepth := c.collectQueueDepth(ctx)
|
||||
activeAlerts := c.collectActiveAlerts(ctx)
|
||||
|
||||
for _, window := range []int{opsMetricsWindowShortMinutes, opsMetricsWindowLongMinutes} {
|
||||
startTime := now.Add(-time.Duration(window) * time.Minute)
|
||||
windowStats, err := c.opsService.GetWindowStats(ctx, startTime, now)
|
||||
if err != nil {
|
||||
log.Printf("[OpsMetrics] failed to get window stats (%dm): %v", window, err)
|
||||
continue
|
||||
}
|
||||
|
||||
successRate, errorRate := computeRates(windowStats.SuccessCount, windowStats.ErrorCount)
|
||||
requestCount := windowStats.SuccessCount + windowStats.ErrorCount
|
||||
metric := &OpsMetrics{
|
||||
WindowMinutes: window,
|
||||
RequestCount: requestCount,
|
||||
SuccessCount: windowStats.SuccessCount,
|
||||
ErrorCount: windowStats.ErrorCount,
|
||||
SuccessRate: successRate,
|
||||
ErrorRate: errorRate,
|
||||
P95LatencyMs: windowStats.P95LatencyMs,
|
||||
P99LatencyMs: windowStats.P99LatencyMs,
|
||||
HTTP2Errors: windowStats.HTTP2Errors,
|
||||
ActiveAlerts: activeAlerts,
|
||||
CPUUsagePercent: systemStats.cpuUsage,
|
||||
MemoryUsedMB: systemStats.memoryUsedMB,
|
||||
MemoryTotalMB: systemStats.memoryTotalMB,
|
||||
MemoryUsagePercent: systemStats.memoryUsagePercent,
|
||||
HeapAllocMB: systemStats.heapAllocMB,
|
||||
GCPauseMs: systemStats.gcPauseMs,
|
||||
ConcurrencyQueueDepth: queueDepth,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if err := c.opsService.RecordMetrics(ctx, metric); err != nil {
|
||||
log.Printf("[OpsMetrics] failed to record metrics (%dm): %v", window, err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func computeRates(successCount, errorCount int64) (float64, float64) {
|
||||
total := successCount + errorCount
|
||||
if total == 0 {
|
||||
// No traffic => no data. Rates are kept at 0 and request_count will be 0.
|
||||
// The UI should render this as N/A instead of "100% success".
|
||||
return 0, 0
|
||||
}
|
||||
successRate := float64(successCount) / float64(total) * percentScale
|
||||
errorRate := float64(errorCount) / float64(total) * percentScale
|
||||
return successRate, errorRate
|
||||
}
|
||||
|
||||
type opsSystemStats struct {
|
||||
cpuUsage float64
|
||||
memoryUsedMB int64
|
||||
memoryTotalMB int64
|
||||
memoryUsagePercent float64
|
||||
heapAllocMB int64
|
||||
gcPauseMs float64
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) opsSystemStats {
|
||||
stats := opsSystemStats{}
|
||||
|
||||
if percents, err := cpu.PercentWithContext(ctx, cpuUsageSampleInterval, false); err == nil && len(percents) > 0 {
|
||||
stats.cpuUsage = percents[0]
|
||||
}
|
||||
|
||||
if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil {
|
||||
stats.memoryUsedMB = int64(vm.Used / bytesPerMB)
|
||||
stats.memoryTotalMB = int64(vm.Total / bytesPerMB)
|
||||
stats.memoryUsagePercent = vm.UsedPercent
|
||||
}
|
||||
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
stats.heapAllocMB = int64(memStats.HeapAlloc / bytesPerMB)
|
||||
c.lastGCPauseMu.Lock()
|
||||
if c.lastGCPauseTotal != 0 && memStats.PauseTotalNs >= c.lastGCPauseTotal {
|
||||
stats.gcPauseMs = float64(memStats.PauseTotalNs-c.lastGCPauseTotal) / float64(time.Millisecond)
|
||||
}
|
||||
c.lastGCPauseTotal = memStats.PauseTotalNs
|
||||
c.lastGCPauseMu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectQueueDepth(ctx context.Context) int {
|
||||
if c.concurrencyService == nil {
|
||||
return 0
|
||||
}
|
||||
depth, err := c.concurrencyService.GetTotalWaitCount(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[OpsMetrics] failed to get queue depth: %v", err)
|
||||
return 0
|
||||
}
|
||||
return depth
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectActiveAlerts(ctx context.Context) int {
|
||||
if c.opsService == nil {
|
||||
return 0
|
||||
}
|
||||
count, err := c.opsService.CountActiveAlerts(ctx)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return count
|
||||
}
|
||||
1020
backend/internal/service/ops_service.go
Normal file
1020
backend/internal/service/ops_service.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySiteName,
|
||||
SettingKeySiteLogo,
|
||||
SettingKeySiteSubtitle,
|
||||
SettingKeyApiBaseUrl,
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocUrl,
|
||||
SettingKeyDocURL,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
if settings.SmtpPassword != "" {
|
||||
updates[SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
|
||||
updates[SettingKeySMTPUsername] = settings.SMTPUsername
|
||||
if settings.SMTPPassword != "" {
|
||||
updates[SettingKeySMTPPassword] = settings.SMTPPassword
|
||||
}
|
||||
updates[SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
updates[SettingKeySMTPFrom] = settings.SMTPFrom
|
||||
updates[SettingKeySMTPFromName] = settings.SMTPFromName
|
||||
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
|
||||
|
||||
// Cloudflare Turnstile 设置(只有非空才更新密钥)
|
||||
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
@@ -115,9 +115,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocUrl] = settings.DocUrl
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -198,8 +198,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySmtpPort: "587",
|
||||
SettingKeySmtpUseTLS: "false",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@@ -210,26 +210,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SmtpHost: settings[SettingKeySmtpHost],
|
||||
SmtpUsername: settings[SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
|
||||
result.SmtpPort = port
|
||||
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
|
||||
result.SMTPPort = port
|
||||
} else {
|
||||
result.SmtpPort = 587
|
||||
result.SMTPPort = 587
|
||||
}
|
||||
|
||||
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
|
||||
@@ -245,8 +245,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
return result
|
||||
@@ -278,28 +278,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminApiKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
|
||||
// GenerateAdminAPIKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", false, nil
|
||||
@@ -320,10 +320,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
|
||||
return maskedKey, true, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
@@ -333,7 +333,7 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
SmtpHost string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFrom string
|
||||
SMTPFromName string
|
||||
SMTPUseTLS bool
|
||||
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
@@ -19,9 +19,9 @@ type SystemSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
DocURL string
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -35,8 +35,8 @@ type PublicSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
DocURL string
|
||||
Version string
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user