527 lines
14 KiB
Go
527 lines
14 KiB
Go
package admin
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// UsageHandler handles admin usage-related requests
|
|
type UsageHandler struct {
|
|
usageService *service.UsageService
|
|
apiKeyService *service.APIKeyService
|
|
adminService service.AdminService
|
|
cleanupService *service.UsageCleanupService
|
|
}
|
|
|
|
// NewUsageHandler creates a new admin usage handler
|
|
func NewUsageHandler(
|
|
usageService *service.UsageService,
|
|
apiKeyService *service.APIKeyService,
|
|
adminService service.AdminService,
|
|
cleanupService *service.UsageCleanupService,
|
|
) *UsageHandler {
|
|
return &UsageHandler{
|
|
usageService: usageService,
|
|
apiKeyService: apiKeyService,
|
|
adminService: adminService,
|
|
cleanupService: cleanupService,
|
|
}
|
|
}
|
|
|
|
// CreateUsageCleanupTaskRequest is the JSON body accepted by the
// cleanup-task creation endpoint. The pointer fields are optional filters
// and stay nil when the corresponding JSON key is absent or null.
type CreateUsageCleanupTaskRequest struct {
	// StartDate is the first day of the range, formatted YYYY-MM-DD. Required.
	StartDate string `json:"start_date"`
	// EndDate is the last day of the range, formatted YYYY-MM-DD. Required.
	EndDate string `json:"end_date"`

	// Optional filters narrowing which usage records the task covers.
	UserID      *int64  `json:"user_id"`
	APIKeyID    *int64  `json:"api_key_id"`
	AccountID   *int64  `json:"account_id"`
	GroupID     *int64  `json:"group_id"`
	Model       *string `json:"model"`
	Stream      *bool   `json:"stream"`
	BillingType *int8   `json:"billing_type"`

	// Timezone is the zone name handed to the date parser together with
	// StartDate and EndDate.
	Timezone string `json:"timezone"`
}
|
|
|
|
// List handles listing all usage records with filters
|
|
// GET /api/v1/admin/usage
|
|
func (h *UsageHandler) List(c *gin.Context) {
|
|
page, pageSize := response.ParsePagination(c)
|
|
|
|
// Parse filters
|
|
var userID, apiKeyID, accountID, groupID int64
|
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
|
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid user_id")
|
|
return
|
|
}
|
|
userID = id
|
|
}
|
|
|
|
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
|
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid api_key_id")
|
|
return
|
|
}
|
|
apiKeyID = id
|
|
}
|
|
|
|
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
|
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid account_id")
|
|
return
|
|
}
|
|
accountID = id
|
|
}
|
|
|
|
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
|
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid group_id")
|
|
return
|
|
}
|
|
groupID = id
|
|
}
|
|
|
|
model := c.Query("model")
|
|
|
|
var stream *bool
|
|
if streamStr := c.Query("stream"); streamStr != "" {
|
|
val, err := strconv.ParseBool(streamStr)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
|
return
|
|
}
|
|
stream = &val
|
|
}
|
|
|
|
var billingType *int8
|
|
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
|
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid billing_type")
|
|
return
|
|
}
|
|
bt := int8(val)
|
|
billingType = &bt
|
|
}
|
|
|
|
// Parse date range
|
|
var startTime, endTime *time.Time
|
|
userTZ := c.Query("timezone") // Get user's timezone from request
|
|
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
|
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
startTime = &t
|
|
}
|
|
|
|
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
|
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
// Set end time to end of day
|
|
t = t.Add(24*time.Hour - time.Nanosecond)
|
|
endTime = &t
|
|
}
|
|
|
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
|
filters := usagestats.UsageLogFilters{
|
|
UserID: userID,
|
|
APIKeyID: apiKeyID,
|
|
AccountID: accountID,
|
|
GroupID: groupID,
|
|
Model: model,
|
|
Stream: stream,
|
|
BillingType: billingType,
|
|
StartTime: startTime,
|
|
EndTime: endTime,
|
|
}
|
|
|
|
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
out := make([]dto.AdminUsageLog, 0, len(records))
|
|
for i := range records {
|
|
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
|
}
|
|
response.Paginated(c, out, result.Total, page, pageSize)
|
|
}
|
|
|
|
// Stats handles getting usage statistics with filters
|
|
// GET /api/v1/admin/usage/stats
|
|
func (h *UsageHandler) Stats(c *gin.Context) {
|
|
// Parse filters - same as List endpoint
|
|
var userID, apiKeyID, accountID, groupID int64
|
|
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
|
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid user_id")
|
|
return
|
|
}
|
|
userID = id
|
|
}
|
|
|
|
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
|
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid api_key_id")
|
|
return
|
|
}
|
|
apiKeyID = id
|
|
}
|
|
|
|
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
|
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid account_id")
|
|
return
|
|
}
|
|
accountID = id
|
|
}
|
|
|
|
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
|
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid group_id")
|
|
return
|
|
}
|
|
groupID = id
|
|
}
|
|
|
|
model := c.Query("model")
|
|
|
|
var stream *bool
|
|
if streamStr := c.Query("stream"); streamStr != "" {
|
|
val, err := strconv.ParseBool(streamStr)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
|
return
|
|
}
|
|
stream = &val
|
|
}
|
|
|
|
var billingType *int8
|
|
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
|
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid billing_type")
|
|
return
|
|
}
|
|
bt := int8(val)
|
|
billingType = &bt
|
|
}
|
|
|
|
// Parse date range
|
|
userTZ := c.Query("timezone")
|
|
now := timezone.NowInUserLocation(userTZ)
|
|
var startTime, endTime time.Time
|
|
|
|
startDateStr := c.Query("start_date")
|
|
endDateStr := c.Query("end_date")
|
|
|
|
if startDateStr != "" && endDateStr != "" {
|
|
var err error
|
|
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
|
} else {
|
|
period := c.DefaultQuery("period", "today")
|
|
switch period {
|
|
case "today":
|
|
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
|
case "week":
|
|
startTime = now.AddDate(0, 0, -7)
|
|
case "month":
|
|
startTime = now.AddDate(0, -1, 0)
|
|
default:
|
|
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
|
|
}
|
|
endTime = now
|
|
}
|
|
|
|
// Build filters and call GetStatsWithFilters
|
|
filters := usagestats.UsageLogFilters{
|
|
UserID: userID,
|
|
APIKeyID: apiKeyID,
|
|
AccountID: accountID,
|
|
GroupID: groupID,
|
|
Model: model,
|
|
Stream: stream,
|
|
BillingType: billingType,
|
|
StartTime: &startTime,
|
|
EndTime: &endTime,
|
|
}
|
|
|
|
stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
response.Success(c, stats)
|
|
}
|
|
|
|
// SearchUsers handles searching users by email keyword
|
|
// GET /api/v1/admin/usage/search-users
|
|
func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
|
keyword := c.Query("q")
|
|
if keyword == "" {
|
|
response.Success(c, []any{})
|
|
return
|
|
}
|
|
|
|
// Limit to 30 results
|
|
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
// Return simplified user list (only id and email)
|
|
type SimpleUser struct {
|
|
ID int64 `json:"id"`
|
|
Email string `json:"email"`
|
|
}
|
|
|
|
result := make([]SimpleUser, len(users))
|
|
for i, u := range users {
|
|
result[i] = SimpleUser{
|
|
ID: u.ID,
|
|
Email: u.Email,
|
|
}
|
|
}
|
|
|
|
response.Success(c, result)
|
|
}
|
|
|
|
// SearchAPIKeys handles searching API keys by user
|
|
// GET /api/v1/admin/usage/search-api-keys
|
|
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
|
userIDStr := c.Query("user_id")
|
|
keyword := c.Query("q")
|
|
|
|
var userID int64
|
|
if userIDStr != "" {
|
|
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid user_id")
|
|
return
|
|
}
|
|
userID = id
|
|
}
|
|
|
|
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
// Return simplified API key list (only id and name)
|
|
type SimpleAPIKey struct {
|
|
ID int64 `json:"id"`
|
|
Name string `json:"name"`
|
|
UserID int64 `json:"user_id"`
|
|
}
|
|
|
|
result := make([]SimpleAPIKey, len(keys))
|
|
for i, k := range keys {
|
|
result[i] = SimpleAPIKey{
|
|
ID: k.ID,
|
|
Name: k.Name,
|
|
UserID: k.UserID,
|
|
}
|
|
}
|
|
|
|
response.Success(c, result)
|
|
}
|
|
|
|
// ListCleanupTasks handles listing usage cleanup tasks
|
|
// GET /api/v1/admin/usage/cleanup-tasks
|
|
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
|
if h.cleanupService == nil {
|
|
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
|
return
|
|
}
|
|
operator := int64(0)
|
|
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
|
operator = subject.UserID
|
|
}
|
|
page, pageSize := response.ParsePagination(c)
|
|
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
|
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
|
if err != nil {
|
|
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
out := make([]dto.UsageCleanupTask, 0, len(tasks))
|
|
for i := range tasks {
|
|
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
|
}
|
|
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
|
response.Paginated(c, out, result.Total, page, pageSize)
|
|
}
|
|
|
|
// CreateCleanupTask handles creating a usage cleanup task
|
|
// POST /api/v1/admin/usage/cleanup-tasks
|
|
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|
if h.cleanupService == nil {
|
|
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
|
return
|
|
}
|
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
|
if !ok || subject.UserID <= 0 {
|
|
response.Unauthorized(c, "Unauthorized")
|
|
return
|
|
}
|
|
|
|
var req CreateUsageCleanupTaskRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
|
return
|
|
}
|
|
req.StartDate = strings.TrimSpace(req.StartDate)
|
|
req.EndDate = strings.TrimSpace(req.EndDate)
|
|
if req.StartDate == "" || req.EndDate == "" {
|
|
response.BadRequest(c, "start_date and end_date are required")
|
|
return
|
|
}
|
|
|
|
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
|
|
if err != nil {
|
|
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
|
return
|
|
}
|
|
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
|
|
|
filters := service.UsageCleanupFilters{
|
|
StartTime: startTime,
|
|
EndTime: endTime,
|
|
UserID: req.UserID,
|
|
APIKeyID: req.APIKeyID,
|
|
AccountID: req.AccountID,
|
|
GroupID: req.GroupID,
|
|
Model: req.Model,
|
|
Stream: req.Stream,
|
|
BillingType: req.BillingType,
|
|
}
|
|
|
|
var userID any
|
|
if filters.UserID != nil {
|
|
userID = *filters.UserID
|
|
}
|
|
var apiKeyID any
|
|
if filters.APIKeyID != nil {
|
|
apiKeyID = *filters.APIKeyID
|
|
}
|
|
var accountID any
|
|
if filters.AccountID != nil {
|
|
accountID = *filters.AccountID
|
|
}
|
|
var groupID any
|
|
if filters.GroupID != nil {
|
|
groupID = *filters.GroupID
|
|
}
|
|
var model any
|
|
if filters.Model != nil {
|
|
model = *filters.Model
|
|
}
|
|
var stream any
|
|
if filters.Stream != nil {
|
|
stream = *filters.Stream
|
|
}
|
|
var billingType any
|
|
if filters.BillingType != nil {
|
|
billingType = *filters.BillingType
|
|
}
|
|
|
|
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
|
subject.UserID,
|
|
filters.StartTime.Format(time.RFC3339),
|
|
filters.EndTime.Format(time.RFC3339),
|
|
userID,
|
|
apiKeyID,
|
|
accountID,
|
|
groupID,
|
|
model,
|
|
stream,
|
|
billingType,
|
|
req.Timezone,
|
|
)
|
|
|
|
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
|
if err != nil {
|
|
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
|
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
|
}
|
|
|
|
// CancelCleanupTask handles canceling a usage cleanup task
|
|
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
|
|
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
|
if h.cleanupService == nil {
|
|
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
|
return
|
|
}
|
|
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
|
if !ok || subject.UserID <= 0 {
|
|
response.Unauthorized(c, "Unauthorized")
|
|
return
|
|
}
|
|
idStr := strings.TrimSpace(c.Param("id"))
|
|
taskID, err := strconv.ParseInt(idStr, 10, 64)
|
|
if err != nil || taskID <= 0 {
|
|
response.BadRequest(c, "Invalid task id")
|
|
return
|
|
}
|
|
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
|
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
|
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
|
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
|
}
|