Files
sub2api/backend/internal/handler/admin/dashboard_handler.go
yangjianbo 2588fa6a8f fix(audit): 第二批审计修复 — P0 生产 Bug、安全加固、性能优化、缓存一致性、代码质量
基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题:

P0 生产 Bug:
- 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03)
- generateRandomID 改用 crypto/rand 替代固定索引 (P0-04)
- IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05)

安全加固:
- gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14)
- usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16)
- 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15)

性能优化:
- gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02)
- group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03)
- usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07)
- GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10)
- ip.go CIDR 预编译为包级变量 (P1-11)
- BatchUpdateCredentials 重构为先验证后更新 (P1-13)

缓存一致性:
- billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10)
- DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12)
- UserService.UpdateBalance 成功后异步失效 billingCache (P2-13)

代码质量:
- search 截断改为按 rune 处理,支持多字节字符 (P2-01)
- TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07)
- CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16)

测试覆盖:
- 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例)
- 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例)
- 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go
- 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试
- config_test.go 补充 server.mode/sslmode 默认值断言

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 19:46:42 +08:00

418 lines
13 KiB
Go

package admin
import (
"errors"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// DashboardHandler serves the admin dashboard statistics endpoints.
type DashboardHandler struct {
	// dashboardService backs the statistics, trend and batch-usage endpoints.
	dashboardService *service.DashboardService

	// aggregationService triggers aggregation backfills. BackfillAggregation
	// checks it for nil before use.
	aggregationService *service.DashboardAggregationService

	// startTime is captured when the handler is constructed and is the
	// reference point for the uptime value reported by GetStats.
	startTime time.Time
}
// NewDashboardHandler builds a DashboardHandler around the given services and
// records the current time as the handler's start time.
func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler {
	handler := new(DashboardHandler)
	handler.dashboardService = dashboardService
	handler.aggregationService = aggregationService
	handler.startTime = time.Now() // uptime is measured from construction
	return handler
}
// parseTimeRange reads the start_date and end_date query parameters (layout
// YYYY-MM-DD) and returns the resulting [start, end) pair.
//
// The optional "timezone" query parameter is forwarded to the timezone helpers;
// per the original note, the server timezone is used when it is not provided.
//
// Fallbacks, applied independently to each bound when the parameter is empty
// or does not parse:
//   - start: timezone.StartOfDayInUserLocation of (now - 7 days)
//   - end:   timezone.StartOfDayInUserLocation of (now + 1 day)
//
// A successfully parsed end_date is advanced by 24 hours so the end date itself
// falls inside the range.
//
// NOTE(review): the 24-hour shift may not land on local midnight on days with a
// DST transition. Callers undo it with Add(-24 * time.Hour), so any change here
// must be made together with them.
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
	const dateLayout = "2006-01-02"

	userTZ := c.Query("timezone")
	now := timezone.NowInUserLocation(userTZ)

	// dayStart resolves the start of the day that lies `offset` days from now.
	dayStart := func(offset int) time.Time {
		return timezone.StartOfDayInUserLocation(now.AddDate(0, 0, offset), userTZ)
	}

	// parseDay reports whether raw held a usable date and, if so, its value.
	parseDay := func(raw string) (time.Time, bool) {
		if raw == "" {
			return time.Time{}, false
		}
		parsed, err := timezone.ParseInUserLocation(dateLayout, raw, userTZ)
		if err != nil {
			return time.Time{}, false
		}
		return parsed, true
	}

	rawStart := c.Query("start_date")
	rawEnd := c.Query("end_date")

	rangeStart, ok := parseDay(rawStart)
	if !ok {
		rangeStart = dayStart(-7)
	}

	rangeEnd, ok := parseDay(rawEnd)
	if ok {
		rangeEnd = rangeEnd.Add(24 * time.Hour) // include the end date
	} else {
		rangeEnd = dayStart(1)
	}

	return rangeStart, rangeEnd
}
// GetStats returns the dashboard statistics plus the handler uptime.
// GET /api/v1/admin/dashboard/stats
//
// When the service call fails the handler answers through response.Error with
// code 500 and a generic message; the underlying error is not exposed.
func (h *DashboardHandler) GetStats(c *gin.Context) {
	stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
	if err != nil {
		response.Error(c, 500, "Failed to get dashboard statistics")
		return
	}

	// Whole seconds elapsed since the handler was constructed.
	uptimeSeconds := int64(time.Since(h.startTime).Seconds())

	payload := gin.H{
		// User statistics
		"total_users":     stats.TotalUsers,
		"today_new_users": stats.TodayNewUsers,
		"active_users":    stats.ActiveUsers,

		// API key statistics
		"total_api_keys":  stats.TotalAPIKeys,
		"active_api_keys": stats.ActiveAPIKeys,

		// Account statistics
		"total_accounts":     stats.TotalAccounts,
		"normal_accounts":    stats.NormalAccounts,
		"error_accounts":     stats.ErrorAccounts,
		"ratelimit_accounts": stats.RateLimitAccounts,
		"overload_accounts":  stats.OverloadAccounts,

		// Cumulative token usage
		"total_requests":              stats.TotalRequests,
		"total_input_tokens":          stats.TotalInputTokens,
		"total_output_tokens":         stats.TotalOutputTokens,
		"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
		"total_cache_read_tokens":     stats.TotalCacheReadTokens,
		"total_tokens":                stats.TotalTokens,
		"total_cost":                  stats.TotalCost,       // standard billing
		"total_actual_cost":           stats.TotalActualCost, // amount actually deducted

		// Today's token usage
		"today_requests":              stats.TodayRequests,
		"today_input_tokens":          stats.TodayInputTokens,
		"today_output_tokens":         stats.TodayOutputTokens,
		"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
		"today_cache_read_tokens":     stats.TodayCacheReadTokens,
		"today_tokens":                stats.TodayTokens,
		"today_cost":                  stats.TodayCost,       // today's standard billing
		"today_actual_cost":           stats.TodayActualCost, // today's amount actually deducted

		// System runtime statistics
		"average_duration_ms": stats.AverageDurationMs,
		"uptime":              uptimeSeconds,

		// Performance metrics
		"rpm": stats.Rpm,
		"tpm": stats.Tpm,

		// Pre-aggregation freshness
		"hourly_active_users": stats.HourlyActiveUsers,
		"stats_updated_at":    stats.StatsUpdatedAt,
		"stats_stale":         stats.StatsStale,
	}
	response.Success(c, payload)
}
// DashboardAggregationBackfillRequest is the JSON body accepted by
// BackfillAggregation. Both fields are parsed there as RFC 3339 timestamps.
type DashboardAggregationBackfillRequest struct {
	// Start is the beginning of the range to backfill.
	Start string `json:"start"`
	// End is the end of the range to backfill.
	End string `json:"end"`
}
// BackfillAggregation triggers an aggregation backfill for a time range.
// POST /api/v1/admin/dashboard/aggregation/backfill
//
// The JSON body carries "start" and "end" as RFC 3339 timestamps. Outcomes:
//   - response.InternalError when the handler has no aggregation service, or
//     when the trigger fails for a reason not listed below
//   - response.BadRequest for an unreadable body, an unparsable timestamp, an
//     end that is not strictly after start, or
//     service.ErrDashboardBackfillTooLarge
//   - response.Forbidden for service.ErrDashboardBackfillDisabled
//   - response.Success with {"status": "accepted"} otherwise
func (h *DashboardHandler) BackfillAggregation(c *gin.Context) {
	if h.aggregationService == nil {
		response.InternalError(c, "Aggregation service not available")
		return
	}

	var req DashboardAggregationBackfillRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request body")
		return
	}

	start, err := time.Parse(time.RFC3339, req.Start)
	if err != nil {
		response.BadRequest(c, "Invalid start time")
		return
	}
	end, err := time.Parse(time.RFC3339, req.End)
	if err != nil {
		response.BadRequest(c, "Invalid end time")
		return
	}

	// An empty or inverted range is a client mistake. Reject it here so it is
	// reported as a bad request instead of surfacing from the service as a
	// generic internal error.
	if !end.After(start) {
		response.BadRequest(c, "Invalid time range: end must be after start")
		return
	}

	if err := h.aggregationService.TriggerBackfill(start, end); err != nil {
		switch {
		case errors.Is(err, service.ErrDashboardBackfillDisabled):
			response.Forbidden(c, "Backfill is disabled")
		case errors.Is(err, service.ErrDashboardBackfillTooLarge):
			response.BadRequest(c, "Backfill range too large")
		default:
			response.InternalError(c, "Failed to trigger backfill")
		}
		return
	}

	response.Success(c, gin.H{
		"status": "accepted",
	})
}
// GetRealtimeMetrics responds with real-time system metrics.
// GET /api/v1/admin/dashboard/realtime
//
// The endpoint is a placeholder for now: every metric is a hard-coded zero.
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
	placeholder := gin.H{
		"active_requests":       0,
		"requests_per_minute":   0,
		"average_response_time": 0,
		"error_rate":            0.0,
	}
	response.Success(c, placeholder)
}
// GetUsageTrend returns usage trend data for a date range.
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
//
// Filter handling:
//   - user_id, api_key_id, account_id, group_id: left at 0 when absent or not
//     a valid integer
//   - model: forwarded as given (empty when absent)
//   - stream: nil when absent or not a valid boolean
//   - billing_type: nil when absent; an unparsable value is rejected through
//     response.BadRequest
//
// NOTE(review): invalid ID and stream filters are dropped silently while an
// invalid billing_type is rejected — confirm the asymmetry is intended.
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)
	granularity := c.DefaultQuery("granularity", "day")

	// optionalID reads an int64 query parameter, yielding 0 when it is missing
	// or cannot be parsed.
	optionalID := func(key string) int64 {
		raw := c.Query(key)
		if raw == "" {
			return 0
		}
		id, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return 0
		}
		return id
	}

	userID := optionalID("user_id")
	apiKeyID := optionalID("api_key_id")
	accountID := optionalID("account_id")
	groupID := optionalID("group_id")
	model := c.Query("model")

	var stream *bool
	if raw := c.Query("stream"); raw != "" {
		if parsed, err := strconv.ParseBool(raw); err == nil {
			stream = &parsed
		}
	}

	var billingType *int8
	if raw := c.Query("billing_type"); raw != "" {
		parsed, err := strconv.ParseInt(raw, 10, 8)
		if err != nil {
			response.BadRequest(c, "Invalid billing_type")
			return
		}
		narrowed := int8(parsed)
		billingType = &narrowed
	}

	trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
	if err != nil {
		response.Error(c, 500, "Failed to get usage trend")
		return
	}

	// parseTimeRange returns an end bound shifted forward to cover the end
	// date; step back 24 hours to report the date itself.
	reportedEnd := endTime.Add(-24 * time.Hour)
	response.Success(c, gin.H{
		"trend":       trend,
		"start_date":  startTime.Format("2006-01-02"),
		"end_date":    reportedEnd.Format("2006-01-02"),
		"granularity": granularity,
	})
}
// GetModelStats returns per-model usage statistics for a date range.
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
//
// Filter handling:
//   - user_id, api_key_id, account_id, group_id: left at 0 when absent or not
//     a valid integer
//   - stream: nil when absent or not a valid boolean
//   - billing_type: nil when absent; an unparsable value is rejected through
//     response.BadRequest
//
// NOTE(review): invalid ID and stream filters are dropped silently while an
// invalid billing_type is rejected — confirm the asymmetry is intended.
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)

	var (
		userID, apiKeyID, accountID, groupID int64
		stream                               *bool
		billingType                          *int8
	)

	// Each optional ID filter keeps its zero value unless the parameter is
	// present and parses as a base-10 int64.
	idFilters := []struct {
		key string
		dst *int64
	}{
		{"user_id", &userID},
		{"api_key_id", &apiKeyID},
		{"account_id", &accountID},
		{"group_id", &groupID},
	}
	for _, filter := range idFilters {
		raw := c.Query(filter.key)
		if raw == "" {
			continue
		}
		if id, err := strconv.ParseInt(raw, 10, 64); err == nil {
			*filter.dst = id
		}
	}

	if raw := c.Query("stream"); raw != "" {
		if parsed, err := strconv.ParseBool(raw); err == nil {
			stream = &parsed
		}
	}

	if raw := c.Query("billing_type"); raw != "" {
		parsed, err := strconv.ParseInt(raw, 10, 8)
		if err != nil {
			response.BadRequest(c, "Invalid billing_type")
			return
		}
		narrowed := int8(parsed)
		billingType = &narrowed
	}

	stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
	if err != nil {
		response.Error(c, 500, "Failed to get model statistics")
		return
	}

	// parseTimeRange returns an end bound shifted forward to cover the end
	// date; step back 24 hours to report the date itself.
	reportedEnd := endTime.Add(-24 * time.Hour)
	response.Success(c, gin.H{
		"models":     stats,
		"start_date": startTime.Format("2006-01-02"),
		"end_date":   reportedEnd.Format("2006-01-02"),
	})
}
// GetAPIKeyUsageTrend returns usage trend data broken down by API key.
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
//
// A limit that is missing, not an integer, or not positive falls back to 5.
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
	const defaultLimit = 5

	startTime, endTime := parseTimeRange(c)
	granularity := c.DefaultQuery("granularity", "day")

	limit := defaultLimit
	if parsed, convErr := strconv.Atoi(c.DefaultQuery("limit", "5")); convErr == nil && parsed > 0 {
		limit = parsed
	}

	trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
	if err != nil {
		response.Error(c, 500, "Failed to get API key usage trend")
		return
	}

	// parseTimeRange returns an end bound shifted forward to cover the end
	// date; step back 24 hours to report the date itself.
	reportedEnd := endTime.Add(-24 * time.Hour)
	response.Success(c, gin.H{
		"trend":       trend,
		"start_date":  startTime.Format("2006-01-02"),
		"end_date":    reportedEnd.Format("2006-01-02"),
		"granularity": granularity,
	})
}
// GetUserUsageTrend returns usage trend data broken down by user.
// GET /api/v1/admin/dashboard/users-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
//
// A limit that is missing, not an integer, or not positive falls back to 12.
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
	const defaultLimit = 12

	startTime, endTime := parseTimeRange(c)
	granularity := c.DefaultQuery("granularity", "day")

	limit := defaultLimit
	if parsed, convErr := strconv.Atoi(c.DefaultQuery("limit", "12")); convErr == nil && parsed > 0 {
		limit = parsed
	}

	trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
	if err != nil {
		response.Error(c, 500, "Failed to get user usage trend")
		return
	}

	// parseTimeRange returns an end bound shifted forward to cover the end
	// date; step back 24 hours to report the date itself.
	reportedEnd := endTime.Add(-24 * time.Hour)
	response.Success(c, gin.H{
		"trend":       trend,
		"start_date":  startTime.Format("2006-01-02"),
		"end_date":    reportedEnd.Format("2006-01-02"),
		"granularity": granularity,
	})
}
// BatchUsersUsageRequest is the JSON body accepted by GetBatchUsersUsage.
type BatchUsersUsageRequest struct {
	// UserIDs lists the users whose usage stats are requested. The field is
	// marked required for request binding.
	UserIDs []int64 `json:"user_ids" binding:"required"`
}
// GetBatchUsersUsage returns usage stats for several users in one call.
// POST /api/v1/admin/dashboard/users-usage
//
// A body that fails binding is rejected through response.BadRequest with the
// binding error appended. An empty ID list short-circuits to an empty "stats"
// object without calling the service.
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
	var payload BatchUsersUsageRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		response.BadRequest(c, "Invalid request: "+bindErr.Error())
		return
	}

	userIDs := payload.UserIDs
	if len(userIDs) == 0 {
		response.Success(c, gin.H{"stats": map[string]any{}})
		return
	}

	// Zero-valued bounds are passed for the time range; presumably the service
	// substitutes its own default window — confirm against the service.
	var unsetBound time.Time
	stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, unsetBound, unsetBound)
	if err != nil {
		response.Error(c, 500, "Failed to get user usage stats")
		return
	}

	response.Success(c, gin.H{"stats": stats})
}
// BatchAPIKeysUsageRequest is the JSON body accepted by GetBatchAPIKeysUsage.
type BatchAPIKeysUsageRequest struct {
	// APIKeyIDs lists the API keys whose usage stats are requested. The field
	// is marked required for request binding.
	APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchAPIKeysUsage returns usage stats for several API keys in one call.
// POST /api/v1/admin/dashboard/api-keys-usage
//
// A body that fails binding is rejected through response.BadRequest with the
// binding error appended. An empty ID list short-circuits to an empty "stats"
// object without calling the service.
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
	var payload BatchAPIKeysUsageRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		response.BadRequest(c, "Invalid request: "+bindErr.Error())
		return
	}

	apiKeyIDs := payload.APIKeyIDs
	if len(apiKeyIDs) == 0 {
		response.Success(c, gin.H{"stats": map[string]any{}})
		return
	}

	// Zero-valued bounds are passed for the time range; presumably the service
	// substitutes its own default window — confirm against the service.
	var unsetBound time.Time
	stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, unsetBound, unsetBound)
	if err != nil {
		response.Error(c, 500, "Failed to get API key usage stats")
		return
	}

	response.Success(c, gin.H{"stats": stats})
}