Merge branch 'feature/ui-and-backend-improvements'
This commit is contained in:
@@ -371,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
userApiKeyIDs := make(map[int64]bool)
|
||||
for _, key := range userApiKeys {
|
||||
userApiKeyIDs[key.ID] = true
|
||||
}
|
||||
|
||||
// Filter to only include user's own API keys
|
||||
validApiKeyIDs := make([]int64, 0)
|
||||
for _, id := range req.ApiKeyIDs {
|
||||
if userApiKeyIDs[id] {
|
||||
validApiKeyIDs = append(validApiKeyIDs, id)
|
||||
}
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
|
||||
@@ -81,6 +81,22 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(apiKeyIDs))
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||
Pluck("id", &ids).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
|
||||
@@ -129,51 +129,67 @@ type DashboardStats = usagestats.DashboardStats
|
||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
today := timezone.Today()
|
||||
now := time.Now()
|
||||
|
||||
// 总用户数
|
||||
r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
|
||||
// 合并用户统计查询
|
||||
var userStats struct {
|
||||
TotalUsers int64 `gorm:"column:total_users"`
|
||||
TodayNewUsers int64 `gorm:"column:today_new_users"`
|
||||
ActiveUsers int64 `gorm:"column:active_users"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_users,
|
||||
COUNT(CASE WHEN created_at >= ? THEN 1 END) as today_new_users,
|
||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= ?) as active_users
|
||||
FROM users
|
||||
`, today, today).Scan(&userStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalUsers = userStats.TotalUsers
|
||||
stats.TodayNewUsers = userStats.TodayNewUsers
|
||||
stats.ActiveUsers = userStats.ActiveUsers
|
||||
|
||||
// 今日新增用户数
|
||||
r.db.WithContext(ctx).Model(&userModel{}).
|
||||
Where("created_at >= ?", today).
|
||||
Count(&stats.TodayNewUsers)
|
||||
// 合并API Key统计查询
|
||||
var apiKeyStats struct {
|
||||
TotalApiKeys int64 `gorm:"column:total_api_keys"`
|
||||
ActiveApiKeys int64 `gorm:"column:active_api_keys"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_api_keys,
|
||||
COUNT(CASE WHEN status = ? THEN 1 END) as active_api_keys
|
||||
FROM api_keys
|
||||
`, service.StatusActive).Scan(&apiKeyStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalApiKeys = apiKeyStats.TotalApiKeys
|
||||
stats.ActiveApiKeys = apiKeyStats.ActiveApiKeys
|
||||
|
||||
// 今日活跃用户数 (今日有请求的用户)
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Distinct("user_id").
|
||||
Where("created_at >= ?", today).
|
||||
Count(&stats.ActiveUsers)
|
||||
|
||||
// 总 API Key 数
|
||||
r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
|
||||
|
||||
// 活跃 API Key 数
|
||||
r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("status = ?", service.StatusActive).
|
||||
Count(&stats.ActiveApiKeys)
|
||||
|
||||
// 总账户数
|
||||
r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
|
||||
|
||||
// 正常账户数 (schedulable=true, status=active)
|
||||
r.db.WithContext(ctx).Model(&accountModel{}).
|
||||
Where("status = ? AND schedulable = ?", service.StatusActive, true).
|
||||
Count(&stats.NormalAccounts)
|
||||
|
||||
// 异常账户数 (status=error)
|
||||
r.db.WithContext(ctx).Model(&accountModel{}).
|
||||
Where("status = ?", service.StatusError).
|
||||
Count(&stats.ErrorAccounts)
|
||||
|
||||
// 限流账户数
|
||||
r.db.WithContext(ctx).Model(&accountModel{}).
|
||||
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
|
||||
Count(&stats.RateLimitAccounts)
|
||||
|
||||
// 过载账户数
|
||||
r.db.WithContext(ctx).Model(&accountModel{}).
|
||||
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
|
||||
Count(&stats.OverloadAccounts)
|
||||
// 合并账户统计查询
|
||||
var accountStats struct {
|
||||
TotalAccounts int64 `gorm:"column:total_accounts"`
|
||||
NormalAccounts int64 `gorm:"column:normal_accounts"`
|
||||
ErrorAccounts int64 `gorm:"column:error_accounts"`
|
||||
RateLimitAccounts int64 `gorm:"column:ratelimit_accounts"`
|
||||
OverloadAccounts int64 `gorm:"column:overload_accounts"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_accounts,
|
||||
COUNT(CASE WHEN status = ? AND schedulable = true THEN 1 END) as normal_accounts,
|
||||
COUNT(CASE WHEN status = ? THEN 1 END) as error_accounts,
|
||||
COUNT(CASE WHEN rate_limited_at IS NOT NULL AND rate_limit_reset_at > ? THEN 1 END) as ratelimit_accounts,
|
||||
COUNT(CASE WHEN overload_until IS NOT NULL AND overload_until > ? THEN 1 END) as overload_accounts
|
||||
FROM accounts
|
||||
`, service.StatusActive, service.StatusError, now, now).Scan(&accountStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalAccounts = accountStats.TotalAccounts
|
||||
stats.NormalAccounts = accountStats.NormalAccounts
|
||||
stats.ErrorAccounts = accountStats.ErrorAccounts
|
||||
stats.RateLimitAccounts = accountStats.RateLimitAccounts
|
||||
stats.OverloadAccounts = accountStats.OverloadAccounts
|
||||
|
||||
// 累计 Token 统计
|
||||
var totalStats struct {
|
||||
@@ -273,6 +289,88 @@ func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
|
||||
return usageLogModelsToService(logs), nil, err
|
||||
}
|
||||
|
||||
// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation
|
||||
func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
var stats struct {
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
|
||||
TotalOutputTokens int64 `gorm:"column:total_output_tokens"`
|
||||
TotalCacheTokens int64 `gorm:"column:total_cache_tokens"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
`).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
Scan(&stats).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usagestats.UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
var stats struct {
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
|
||||
TotalOutputTokens int64 `gorm:"column:total_output_tokens"`
|
||||
TotalCacheTokens int64 `gorm:"column:total_cache_tokens"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
`).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
Scan(&stats).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usagestats.UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
err := r.db.WithContext(ctx).
|
||||
|
||||
@@ -788,6 +788,25 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(apiKeyIDs))
|
||||
out := make([]int64, 0, len(apiKeyIDs))
|
||||
for _, id := range apiKeyIDs {
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
key, ok := r.byID[id]
|
||||
if ok && key.UserID == userID {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
for _, key := range r.byID {
|
||||
@@ -903,6 +922,55 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
logs := r.userLogs[userID]
|
||||
if len(logs) == 0 {
|
||||
return &usagestats.UsageStats{}, nil
|
||||
}
|
||||
|
||||
var totalRequests int64
|
||||
var totalInputTokens int64
|
||||
var totalOutputTokens int64
|
||||
var totalCacheTokens int64
|
||||
var totalCost float64
|
||||
var totalActualCost float64
|
||||
var totalDuration int64
|
||||
var durationCount int64
|
||||
|
||||
for _, log := range logs {
|
||||
totalRequests++
|
||||
totalInputTokens += int64(log.InputTokens)
|
||||
totalOutputTokens += int64(log.OutputTokens)
|
||||
totalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
totalCost += log.TotalCost
|
||||
totalActualCost += log.ActualCost
|
||||
if log.DurationMs != nil {
|
||||
totalDuration += int64(*log.DurationMs)
|
||||
durationCount++
|
||||
}
|
||||
}
|
||||
|
||||
var avgDuration float64
|
||||
if durationCount > 0 {
|
||||
avgDuration = float64(totalDuration) / float64(durationCount)
|
||||
}
|
||||
|
||||
return &usagestats.UsageStats{
|
||||
TotalRequests: totalRequests,
|
||||
TotalInputTokens: totalInputTokens,
|
||||
TotalOutputTokens: totalOutputTokens,
|
||||
TotalCacheTokens: totalCacheTokens,
|
||||
TotalTokens: totalInputTokens + totalOutputTokens + totalCacheTokens,
|
||||
TotalCost: totalCost,
|
||||
TotalActualCost: totalActualCost,
|
||||
AverageDurationMs: avgDuration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -48,6 +48,10 @@ type UsageLogRepository interface {
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
}
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
|
||||
@@ -34,6 +34,7 @@ type ApiKeyRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
@@ -256,6 +257,18 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||||
}
|
||||
return validIDs, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -148,22 +148,40 @@ func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, param
|
||||
|
||||
// GetStatsByUser 获取用户的使用统计
|
||||
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetUserStatsAggregated(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get user stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByApiKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get api key stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByAccount 获取账号的使用统计
|
||||
|
||||
Reference in New Issue
Block a user