package admin

import (
	"strconv"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
	"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
)

// usageDateLayout is the YYYY-MM-DD layout accepted by the start_date and
// end_date query parameters.
const usageDateLayout = "2006-01-02"

// usageSearchLimit caps the number of results returned by the search endpoints.
const usageSearchLimit = 30

// UsageHandler handles admin usage-related requests.
type UsageHandler struct {
	usageService  *service.UsageService
	apiKeyService *service.APIKeyService
	adminService  service.AdminService
}

// NewUsageHandler creates a new admin usage handler.
func NewUsageHandler(
	usageService *service.UsageService,
	apiKeyService *service.APIKeyService,
	adminService service.AdminService,
) *UsageHandler {
	return &UsageHandler{
		usageService:  usageService,
		apiKeyService: apiKeyService,
		adminService:  adminService,
	}
}

// List handles listing all usage records with filters.
// GET /api/v1/admin/usage
//
// Besides pagination and the common filters (see parseUsageLogFilters), it
// accepts optional start_date / end_date (YYYY-MM-DD) parsed in the timezone
// given by the "timezone" query parameter. Each bound is independent; an
// omitted bound is passed to the service as nil.
func (h *UsageHandler) List(c *gin.Context) {
	page, pageSize := response.ParsePagination(c)

	filters, ok := parseUsageLogFilters(c)
	if !ok {
		return
	}

	userTZ := c.Query("timezone")
	if startDateStr := c.Query("start_date"); startDateStr != "" {
		t, err := timezone.ParseInUserLocation(usageDateLayout, startDateStr, userTZ)
		if err != nil {
			response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
			return
		}
		filters.StartTime = &t
	}
	if endDateStr := c.Query("end_date"); endDateStr != "" {
		t, err := timezone.ParseInUserLocation(usageDateLayout, endDateStr, userTZ)
		if err != nil {
			response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
			return
		}
		// Make the end date inclusive of the whole calendar day.
		t = endOfUsageDay(t)
		filters.EndTime = &t
	}

	params := pagination.PaginationParams{Page: page, PageSize: pageSize}
	records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	out := make([]dto.UsageLog, 0, len(records))
	for i := range records {
		out = append(out, *dto.UsageLogFromService(&records[i]))
	}
	response.Paginated(c, out, result.Total, page, pageSize)
}

// Stats handles getting usage statistics with filters.
// GET /api/v1/admin/usage/stats
//
// It accepts the same common filters as List. The time range is taken from
// start_date and end_date (YYYY-MM-DD, in the "timezone" query parameter's
// zone) only when BOTH are supplied. Otherwise the "period" parameter is used:
// "week" (now minus 7 days), "month" (now minus 1 month), or "today"/anything
// else (start of the current day); the range then ends at now.
func (h *UsageHandler) Stats(c *gin.Context) {
	filters, ok := parseUsageLogFilters(c)
	if !ok {
		return
	}

	userTZ := c.Query("timezone")
	now := timezone.NowInUserLocation(userTZ)

	var startTime, endTime time.Time
	startDateStr := c.Query("start_date")
	endDateStr := c.Query("end_date")
	if startDateStr != "" && endDateStr != "" {
		var err error
		startTime, err = timezone.ParseInUserLocation(usageDateLayout, startDateStr, userTZ)
		if err != nil {
			response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
			return
		}
		endTime, err = timezone.ParseInUserLocation(usageDateLayout, endDateStr, userTZ)
		if err != nil {
			response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
			return
		}
		// Make the end date inclusive of the whole calendar day.
		endTime = endOfUsageDay(endTime)
	} else {
		switch c.DefaultQuery("period", "today") {
		case "week":
			startTime = now.AddDate(0, 0, -7)
		case "month":
			startTime = now.AddDate(0, -1, 0)
		default: // "today" and any unrecognised value
			startTime = timezone.StartOfDayInUserLocation(now, userTZ)
		}
		endTime = now
	}
	filters.StartTime = &startTime
	filters.EndTime = &endTime

	stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, stats)
}

// SearchUsers handles searching users by email keyword.
// GET /api/v1/admin/usage/search-users
//
// An empty "q" returns an empty list without querying the service. At most
// usageSearchLimit users are returned, each reduced to id and email.
func (h *UsageHandler) SearchUsers(c *gin.Context) {
	keyword := c.Query("q")
	if keyword == "" {
		response.Success(c, []any{})
		return
	}

	// Only the first page is requested; the second return value is not needed.
	users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, usageSearchLimit, service.UserListFilters{Search: keyword})
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Return simplified user list (only id and email).
	type SimpleUser struct {
		ID    int64  `json:"id"`
		Email string `json:"email"`
	}
	result := make([]SimpleUser, len(users))
	for i, u := range users {
		result[i] = SimpleUser{
			ID:    u.ID,
			Email: u.Email,
		}
	}
	response.Success(c, result)
}

// SearchAPIKeys handles searching API keys by user.
// GET /api/v1/admin/usage/search-api-keys
//
// Optional query parameters: user_id (int64) and q (keyword). When user_id is
// absent, 0 is passed to the service (presumably meaning "any user"; confirm
// against APIKeyService.SearchAPIKeys). At most usageSearchLimit keys are
// returned.
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
	userID, ok := parseUsageInt64Query(c, "user_id")
	if !ok {
		return
	}
	keyword := c.Query("q")

	keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, usageSearchLimit)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Return simplified API key list (id, name and owning user_id).
	type SimpleAPIKey struct {
		ID     int64  `json:"id"`
		Name   string `json:"name"`
		UserID int64  `json:"user_id"`
	}
	result := make([]SimpleAPIKey, len(keys))
	for i, k := range keys {
		result[i] = SimpleAPIKey{
			ID:     k.ID,
			Name:   k.Name,
			UserID: k.UserID,
		}
	}
	response.Success(c, result)
}

// parseUsageLogFilters parses the filter query parameters shared by List and
// Stats: user_id, api_key_id, account_id, group_id, model, stream and
// billing_type. StartTime/EndTime are left nil because the two endpoints
// derive the time range differently.
//
// On invalid input it writes a 400 response and returns ok=false; the caller
// must then return without writing another response.
func parseUsageLogFilters(c *gin.Context) (usagestats.UsageLogFilters, bool) {
	var f usagestats.UsageLogFilters
	var ok bool

	if f.UserID, ok = parseUsageInt64Query(c, "user_id"); !ok {
		return usagestats.UsageLogFilters{}, false
	}
	if f.APIKeyID, ok = parseUsageInt64Query(c, "api_key_id"); !ok {
		return usagestats.UsageLogFilters{}, false
	}
	if f.AccountID, ok = parseUsageInt64Query(c, "account_id"); !ok {
		return usagestats.UsageLogFilters{}, false
	}
	if f.GroupID, ok = parseUsageInt64Query(c, "group_id"); !ok {
		return usagestats.UsageLogFilters{}, false
	}

	f.Model = c.Query("model")

	// stream is tri-state: nil means "don't filter".
	if streamStr := c.Query("stream"); streamStr != "" {
		val, err := strconv.ParseBool(streamStr)
		if err != nil {
			response.BadRequest(c, "Invalid stream value, use true or false")
			return usagestats.UsageLogFilters{}, false
		}
		f.Stream = &val
	}

	// billing_type must fit in an int8; nil means "don't filter".
	if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
		val, err := strconv.ParseInt(billingTypeStr, 10, 8)
		if err != nil {
			response.BadRequest(c, "Invalid billing_type")
			return usagestats.UsageLogFilters{}, false
		}
		bt := int8(val)
		f.BillingType = &bt
	}

	return f, true
}

// parseUsageInt64Query reads an optional base-10 int64 query parameter named
// key. A missing or empty parameter yields (0, true). An unparsable value
// writes a 400 response with the message "Invalid <key>" and returns
// (0, false).
func parseUsageInt64Query(c *gin.Context, key string) (int64, bool) {
	raw := c.Query(key)
	if raw == "" {
		return 0, true
	}
	v, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		response.BadRequest(c, "Invalid "+key)
		return 0, false
	}
	return v, true
}

// endOfUsageDay returns the last nanosecond of the calendar day that starts at
// dayStart, evaluated in dayStart's location. AddDate is used instead of adding
// a fixed 24h so that days containing a DST transition (23 or 25 hours long)
// end at the correct local instant.
func endOfUsageDay(dayStart time.Time) time.Time {
	return dayStart.AddDate(0, 0, 1).Add(-time.Nanosecond)
}