package handler

import (
	"strconv"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
	"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
	middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
)

// UsageHandler handles usage-related requests for the authenticated user.
type UsageHandler struct {
	usageService  *service.UsageService
	apiKeyService *service.ApiKeyService
}

// NewUsageHandler creates a new UsageHandler.
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
	return &UsageHandler{
		usageService:  usageService,
		apiKeyService: apiKeyService,
	}
}

// List handles listing the caller's usage records with pagination.
// GET /api/v1/usage
//
// Optional query parameters:
//   - api_key_id:   restrict to one API key; the key must belong to the caller
//   - model:        restrict to a model name
//   - stream:       "true"/"false" (any value accepted by strconv.ParseBool)
//   - billing_type: integer that must fit in an int8
//   - start_date, end_date: YYYY-MM-DD parsed in the configured timezone;
//     end_date is inclusive (extended to the last instant of that day)
//
// Records are always filtered by the authenticated user's ID.
func (h *UsageHandler) List(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	page, pageSize := response.ParsePagination(c)

	var apiKeyID int64
	if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
		id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
		if err != nil {
			response.BadRequest(c, "Invalid api_key_id")
			return
		}

		// [Security Fix] Verify API key ownership to prevent horizontal privilege escalation.
		apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		if apiKey.UserID != subject.UserID {
			response.Forbidden(c, "Not authorized to access this API key's usage records")
			return
		}
		apiKeyID = id
	}

	// Parse additional filters.
	model := c.Query("model")

	var stream *bool
	if streamStr := c.Query("stream"); streamStr != "" {
		val, err := strconv.ParseBool(streamStr)
		if err != nil {
			response.BadRequest(c, "Invalid stream value, use true or false")
			return
		}
		stream = &val
	}

	var billingType *int8
	if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
		// bitSize 8 makes ParseInt reject values outside the int8 range.
		val, err := strconv.ParseInt(billingTypeStr, 10, 8)
		if err != nil {
			response.BadRequest(c, "Invalid billing_type")
			return
		}
		bt := int8(val)
		billingType = &bt
	}

	// Parse the date range.
	var startTime, endTime *time.Time
	if startDateStr := c.Query("start_date"); startDateStr != "" {
		t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
		if err != nil {
			response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
			return
		}
		startTime = &t
	}
	if endDateStr := c.Query("end_date"); endDateStr != "" {
		t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
		if err != nil {
			response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
			return
		}
		// Extend to the last instant of that calendar day. AddDate is used instead of
		// Add(24h) because days on DST transitions are 23 or 25 hours long.
		t = t.AddDate(0, 0, 1).Add(-time.Nanosecond)
		endTime = &t
	}

	params := pagination.PaginationParams{Page: page, PageSize: pageSize}
	filters := usagestats.UsageLogFilters{
		UserID:      subject.UserID, // Always filter by the current user for security.
		ApiKeyID:    apiKeyID,
		Model:       model,
		Stream:      stream,
		BillingType: billingType,
		StartTime:   startTime,
		EndTime:     endTime,
	}

	records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	out := make([]dto.UsageLog, 0, len(records))
	for i := range records {
		out = append(out, *dto.UsageLogFromService(&records[i]))
	}
	response.Paginated(c, out, result.Total, page, pageSize)
}

// GetByID handles getting a single usage record owned by the caller.
// GET /api/v1/usage/:id
//
// Responds 400 for a non-numeric id and 403 when the record belongs to another user.
func (h *UsageHandler) GetByID(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		response.BadRequest(c, "Invalid usage ID")
		return
	}

	record, err := h.usageService.GetByID(c.Request.Context(), usageID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Verify ownership.
	if record.UserID != subject.UserID {
		response.Forbidden(c, "Not authorized to access this record")
		return
	}

	response.Success(c, dto.UsageLogFromService(record))
}

// Stats handles getting aggregated usage statistics for the caller.
// GET /api/v1/usage/stats
//
// Query parameters:
//   - api_key_id: optional; when set, statistics cover only that key, which
//     must belong to the caller.
//   - start_date and end_date: YYYY-MM-DD in the configured timezone. Both
//     must be present to take effect; end_date is inclusive.
//   - period: used when the date pair is absent. "today" (default) starts at
//     the start of today, "week" at now-7 days, "month" at now-1 month;
//     unknown values behave like "today". The range ends at now.
func (h *UsageHandler) Stats(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	var apiKeyID int64
	if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
		id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
		if err != nil {
			response.BadRequest(c, "Invalid api_key_id")
			return
		}

		// [Security Fix] Verify API key ownership to prevent horizontal privilege escalation.
		// Errors are passed through ErrorFrom (as in List) so internal failures are not
		// misreported as "not found".
		apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		if apiKey.UserID != subject.UserID {
			response.Forbidden(c, "Not authorized to access this API key's statistics")
			return
		}
		apiKeyID = id
	}

	// Resolve the time range.
	now := timezone.Now()
	var startTime, endTime time.Time

	// An explicit start_date/end_date pair takes precedence over period.
	startDateStr := c.Query("start_date")
	endDateStr := c.Query("end_date")
	if startDateStr != "" && endDateStr != "" {
		var err error
		startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
		if err != nil {
			response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
			return
		}
		endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
		if err != nil {
			response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
			return
		}
		// Last instant of the end date; AddDate keeps this correct on DST transition days.
		endTime = endTime.AddDate(0, 0, 1).Add(-time.Nanosecond)
	} else {
		switch c.DefaultQuery("period", "today") {
		case "week":
			startTime = now.AddDate(0, 0, -7)
		case "month":
			startTime = now.AddDate(0, -1, 0)
		default: // "today" and any unrecognized value
			startTime = timezone.StartOfDay(now)
		}
		endTime = now
	}

	var stats *service.UsageStats
	var err error
	if apiKeyID > 0 {
		stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
	} else {
		stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
	}
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, stats)
}

// parseUserTimeRange parses the start_date and end_date query parameters
// (YYYY-MM-DD, configured timezone) for the user dashboard endpoints.
//
// It returns a half-open range [start, end):
//   - start is the beginning of start_date. A missing or malformed start_date
//     falls back to the start of the day seven days ago.
//   - end is the beginning of the day after end_date, so end_date is included.
//     A missing or malformed end_date falls back to the start of tomorrow, so
//     today is included.
//
// Callers that echo end_date back should subtract one calendar day with
// AddDate(0, 0, -1), not 24 hours, to stay correct across DST changes.
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
	now := timezone.Now()

	startTime := timezone.StartOfDay(now.AddDate(0, 0, -7))
	if startDate := c.Query("start_date"); startDate != "" {
		if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
			startTime = t
		}
	}

	endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
	if endDate := c.Query("end_date"); endDate != "" {
		if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
			// Next calendar day. Add(24h) would miss or overshoot the boundary on DST days.
			endTime = t.AddDate(0, 0, 1)
		}
	}

	return startTime, endTime
}

// DashboardStats handles getting the caller's dashboard statistics.
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, stats)
}

// DashboardTrend handles getting the caller's usage trend data.
// GET /api/v1/usage/dashboard/trend
//
// Accepts start_date/end_date (see parseUserTimeRange) and granularity
// (default "day"; passed through to the service unvalidated). The response
// echoes the effective inclusive date range.
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	startTime, endTime := parseUserTimeRange(c)
	granularity := c.DefaultQuery("granularity", "day")

	trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, gin.H{
		"trend":       trend,
		"start_date":  startTime.Format("2006-01-02"),
		"end_date":    endTime.AddDate(0, 0, -1).Format("2006-01-02"), // endTime is exclusive
		"granularity": granularity,
	})
}

// DashboardModels handles getting the caller's per-model usage statistics.
// GET /api/v1/usage/dashboard/models
//
// Accepts start_date/end_date (see parseUserTimeRange). The response echoes
// the effective inclusive date range.
func (h *UsageHandler) DashboardModels(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	startTime, endTime := parseUserTimeRange(c)

	stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, gin.H{
		"models":     stats,
		"start_date": startTime.Format("2006-01-02"),
		"end_date":   endTime.AddDate(0, 0, -1).Format("2006-01-02"), // endTime is exclusive
	})
}

// BatchApiKeysUsageRequest represents the request for batch API keys usage.
type BatchApiKeysUsageRequest struct {
	ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}

// DashboardApiKeysUsage handles getting usage stats for the caller's own API keys.
// POST /api/v1/usage/dashboard/api-keys-usage
//
// IDs in the request that do not belong to the caller are silently dropped,
// and duplicates are ignored. The first-occurrence order of the remaining IDs
// is preserved. If no requested ID is owned by the caller, an empty stats
// object is returned.
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	var req BatchApiKeysUsageRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	if len(req.ApiKeyIDs) == 0 {
		response.Success(c, gin.H{"stats": map[string]any{}})
		return
	}

	requested := make(map[int64]struct{}, len(req.ApiKeyIDs))
	for _, id := range req.ApiKeyIDs {
		requested[id] = struct{}{}
	}

	// Verify ownership of the requested API keys. Page through the caller's keys
	// instead of assuming one page holds them all, so keys past the first page
	// are not silently dropped. Stop early once every requested ID has been found.
	// The page cap guarantees termination even if the list endpoint misbehaves.
	const (
		ownershipPageSize = 1000
		maxOwnershipPages = 100
	)
	owned := make(map[int64]struct{}, len(requested))
	for page := 1; page <= maxOwnershipPages && len(owned) < len(requested); page++ {
		keys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: page, PageSize: ownershipPageSize})
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		for _, key := range keys {
			if _, ok := requested[key.ID]; ok {
				owned[key.ID] = struct{}{}
			}
		}
		if len(keys) < ownershipPageSize {
			break // last page
		}
	}

	// Keep only the caller's own keys, in request order, without duplicates.
	validApiKeyIDs := make([]int64, 0, len(owned))
	for _, id := range req.ApiKeyIDs {
		if _, ok := owned[id]; ok {
			validApiKeyIDs = append(validApiKeyIDs, id)
			delete(owned, id) // so a repeated ID is not appended twice
		}
	}

	if len(validApiKeyIDs) == 0 {
		response.Success(c, gin.H{"stats": map[string]any{}})
		return
	}

	stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, gin.H{"stats": stats})
}