package admin import ( "log" "net/http" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) // UsageHandler handles admin usage-related requests type UsageHandler struct { usageService *service.UsageService apiKeyService *service.APIKeyService adminService service.AdminService cleanupService *service.UsageCleanupService } // NewUsageHandler creates a new admin usage handler func NewUsageHandler( usageService *service.UsageService, apiKeyService *service.APIKeyService, adminService service.AdminService, cleanupService *service.UsageCleanupService, ) *UsageHandler { return &UsageHandler{ usageService: usageService, apiKeyService: apiKeyService, adminService: adminService, cleanupService: cleanupService, } } // CreateUsageCleanupTaskRequest represents cleanup task creation request type CreateUsageCleanupTaskRequest struct { StartDate string `json:"start_date"` EndDate string `json:"end_date"` UserID *int64 `json:"user_id"` APIKeyID *int64 `json:"api_key_id"` AccountID *int64 `json:"account_id"` GroupID *int64 `json:"group_id"` Model *string `json:"model"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` Timezone string `json:"timezone"` } // List handles listing all usage records with filters // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) // Parse filters var userID, apiKeyID, accountID, groupID int64 if userIDStr := c.Query("user_id"); userIDStr != "" { id, err := strconv.ParseInt(userIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid user_id") return } userID = id } if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid api_key_id") return } apiKeyID = id } if accountIDStr := c.Query("account_id"); accountIDStr != "" { id, err := strconv.ParseInt(accountIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid account_id") return } accountID = id } if groupIDStr := c.Query("group_id"); groupIDStr != "" { id, err := strconv.ParseInt(groupIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid group_id") return } groupID = id } model := c.Query("model") var stream *bool if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") return } stream = &val } var billingType *int8 if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { val, err := strconv.ParseInt(billingTypeStr, 10, 8) if err != nil { response.BadRequest(c, "Invalid billing_type") return } bt := int8(val) billingType = &bt } // Parse date range var startTime, endTime *time.Time userTZ := c.Query("timezone") // Get user's timezone from request if startDateStr := c.Query("start_date"); startDateStr != "" { t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ) if err != nil { response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") return } startTime = &t } if endDateStr := c.Query("end_date"); endDateStr != "" { t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ) if err != nil { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } // Set end time to end of day t = t.Add(24*time.Hour - time.Nanosecond) endTime = &t } params := pagination.PaginationParams{Page: page, PageSize: pageSize} filters := usagestats.UsageLogFilters{ UserID: userID, APIKeyID: apiKeyID, AccountID: accountID, GroupID: groupID, Model: model, Stream: stream, BillingType: billingType, StartTime: startTime, EndTime: endTime, } records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) if err != nil { response.ErrorFrom(c, err) return } out := make([]dto.UsageLog, 0, len(records)) for i := range records { out = append(out, *dto.UsageLogFromServiceAdmin(&records[i])) } response.Paginated(c, out, result.Total, page, pageSize) } // Stats handles getting usage statistics with filters // GET /api/v1/admin/usage/stats func (h *UsageHandler) Stats(c *gin.Context) { // Parse filters - same as List endpoint var userID, apiKeyID, accountID, groupID int64 if userIDStr := c.Query("user_id"); userIDStr != "" { id, err := strconv.ParseInt(userIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid user_id") return } userID = id } if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid api_key_id") return } apiKeyID = id } if accountIDStr := c.Query("account_id"); accountIDStr != "" { id, err := strconv.ParseInt(accountIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid account_id") return } accountID = id } if groupIDStr := c.Query("group_id"); groupIDStr != "" { id, err := strconv.ParseInt(groupIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid group_id") return } groupID = id } model := c.Query("model") var stream *bool if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") return } stream = &val } var billingType *int8 if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { val, err := strconv.ParseInt(billingTypeStr, 10, 8) if err != nil { response.BadRequest(c, "Invalid billing_type") return } bt := int8(val) billingType = &bt } // Parse date range userTZ := c.Query("timezone") now := timezone.NowInUserLocation(userTZ) var startTime, endTime time.Time startDateStr := c.Query("start_date") endDateStr := c.Query("end_date") if startDateStr != "" && endDateStr != "" { var err error startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ) if err != nil { response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") return } endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ) if err != nil { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } endTime = endTime.Add(24*time.Hour - time.Nanosecond) } else { period := c.DefaultQuery("period", "today") switch period { case "today": startTime = timezone.StartOfDayInUserLocation(now, userTZ) case "week": startTime = now.AddDate(0, 0, -7) case "month": startTime = now.AddDate(0, -1, 0) default: startTime = timezone.StartOfDayInUserLocation(now, userTZ) } endTime = now } // Build filters and call GetStatsWithFilters filters := usagestats.UsageLogFilters{ UserID: userID, APIKeyID: apiKeyID, AccountID: accountID, GroupID: groupID, Model: model, Stream: stream, BillingType: billingType, StartTime: &startTime, EndTime: &endTime, } stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters) if err != nil { response.ErrorFrom(c, err) return } response.Success(c, stats) } // SearchUsers handles searching users by email keyword // GET /api/v1/admin/usage/search-users func (h *UsageHandler) SearchUsers(c *gin.Context) { keyword := c.Query("q") if keyword == "" { response.Success(c, []any{}) return } // Limit to 30 results users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}) if err != nil { response.ErrorFrom(c, err) return } // Return simplified user list (only id and email) type SimpleUser struct { ID int64 `json:"id"` Email string `json:"email"` } result := make([]SimpleUser, len(users)) for i, u := range users { result[i] = SimpleUser{ ID: u.ID, Email: u.Email, } } response.Success(c, result) } // SearchAPIKeys handles searching API keys by user // GET /api/v1/admin/usage/search-api-keys func (h *UsageHandler) SearchAPIKeys(c *gin.Context) { userIDStr := c.Query("user_id") keyword := c.Query("q") var userID int64 if userIDStr != "" { id, err := strconv.ParseInt(userIDStr, 10, 64) if err != nil { response.BadRequest(c, "Invalid user_id") return } userID = id } keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30) if err != nil { response.ErrorFrom(c, err) return } // Return simplified API key list (only id and name) type SimpleAPIKey struct { ID int64 `json:"id"` Name string `json:"name"` UserID int64 `json:"user_id"` } result := make([]SimpleAPIKey, len(keys)) for i, k := range keys { result[i] = SimpleAPIKey{ ID: k.ID, Name: k.Name, UserID: k.UserID, } } response.Success(c, result) } // ListCleanupTasks handles listing usage cleanup tasks // GET /api/v1/admin/usage/cleanup-tasks func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { if h.cleanupService == nil { response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") return } operator := int64(0) if subject, ok := middleware.GetAuthSubjectFromContext(c); ok { operator = subject.UserID } page, pageSize := response.ParsePagination(c) log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) params := pagination.PaginationParams{Page: page, PageSize: pageSize} tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) if err != nil { log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) response.ErrorFrom(c, err) return } out := make([]dto.UsageCleanupTask, 0, len(tasks)) for i := range tasks { out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) } log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) response.Paginated(c, out, result.Total, page, pageSize) } // CreateCleanupTask handles creating a usage cleanup task // POST /api/v1/admin/usage/cleanup-tasks func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { if h.cleanupService == nil { response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") return } subject, ok := middleware.GetAuthSubjectFromContext(c) if !ok || subject.UserID <= 0 { response.Unauthorized(c, "Unauthorized") return } var req CreateUsageCleanupTaskRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } req.StartDate = strings.TrimSpace(req.StartDate) req.EndDate = strings.TrimSpace(req.EndDate) if req.StartDate == "" || req.EndDate == "" { response.BadRequest(c, "start_date and end_date are required") return } startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone) if err != nil { response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") return } endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone) if err != nil { response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") return } endTime = endTime.Add(24*time.Hour - time.Nanosecond) filters := service.UsageCleanupFilters{ StartTime: startTime, EndTime: endTime, UserID: req.UserID, APIKeyID: req.APIKeyID, AccountID: req.AccountID, GroupID: req.GroupID, Model: req.Model, Stream: req.Stream, BillingType: req.BillingType, } var userID any if filters.UserID != nil { userID = *filters.UserID } var apiKeyID any if filters.APIKeyID != nil { apiKeyID = *filters.APIKeyID } var accountID any if filters.AccountID != nil { accountID = *filters.AccountID } var groupID any if filters.GroupID != nil { groupID = *filters.GroupID } var model any if filters.Model != nil { model = *filters.Model } var stream any if filters.Stream != nil { stream = *filters.Stream } var billingType any if filters.BillingType != nil { billingType = *filters.BillingType } log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", subject.UserID, filters.StartTime.Format(time.RFC3339), filters.EndTime.Format(time.RFC3339), userID, apiKeyID, accountID, groupID, model, stream, billingType, req.Timezone, ) task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) if err != nil { log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) response.ErrorFrom(c, err) return } log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) response.Success(c, dto.UsageCleanupTaskFromService(task)) } // CancelCleanupTask handles canceling a usage cleanup task // POST /api/v1/admin/usage/cleanup-tasks/:id/cancel func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { if h.cleanupService == nil { response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") return } subject, ok := middleware.GetAuthSubjectFromContext(c) if !ok || subject.UserID <= 0 { response.Unauthorized(c, "Unauthorized") return } idStr := strings.TrimSpace(c.Param("id")) taskID, err := strconv.ParseInt(idStr, 10, 64) if err != nil || taskID <= 0 { response.BadRequest(c, "Invalid task id") return } log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) response.ErrorFrom(c, err) return } log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) }