From 227d506c53965eda7827a022acba9d3685454f92 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Sat, 27 Dec 2025 16:03:57 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):=20=E5=A2=9E=E5=BC=BA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=BB=9F=E8=AE=A1=E5=92=8CAPI=E5=AF=86=E9=92=A5?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化使用统计处理逻辑 - 增强API密钥仓储层功能 - 改进账户使用服务 - 完善API契约测试覆盖 --- backend/internal/handler/usage_handler.go | 22 +++++--------- backend/internal/repository/api_key_repo.go | 16 ++++++++++ backend/internal/server/api_contract_test.go | 19 ++++++++++++ .../internal/service/account_usage_service.go | 4 +++ backend/internal/service/api_key_service.go | 13 ++++++++ backend/internal/service/usage_service.go | 30 +++++++++++++++---- 6 files changed, 83 insertions(+), 21 deletions(-) diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 15b30bbb..a0cf9f2c 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -371,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { return } - // Verify ownership of all requested API keys - userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000}) - if err != nil { - response.ErrorFrom(c, err) + // Limit the number of API key IDs to prevent SQL parameter overflow + if len(req.ApiKeyIDs) > 100 { + response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)") return } - userApiKeyIDs := make(map[int64]bool) - for _, key := range userApiKeys { - userApiKeyIDs[key.ID] = true - } - - // Filter to only include user's own API keys - validApiKeyIDs := make([]int64, 0) - for _, id := range req.ApiKeyIDs { - if userApiKeyIDs[id] { - validApiKeyIDs = append(validApiKeyIDs, id) - } + validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs) + if err != nil { + response.ErrorFrom(c, err) + return } if len(validApiKeyIDs) == 0 { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 718bef33..a6001ecc 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -81,6 +81,22 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param return outKeys, paginationResultFromTotal(total, params), nil } +func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + if len(apiKeyIDs) == 0 { + return []int64{}, nil + } + + ids := make([]int64, 0, len(apiKeyIDs)) + err := r.db.WithContext(ctx). + Model(&apiKeyModel{}). + Where("user_id = ? AND id IN ?", userID, apiKeyIDs). + Pluck("id", &ids).Error + if err != nil { + return nil, err + } + return ids, nil +} + func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 1aeedf8d..55f83afa 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -788,6 +788,25 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params }, nil } +func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + if len(apiKeyIDs) == 0 { + return []int64{}, nil + } + seen := make(map[int64]struct{}, len(apiKeyIDs)) + out := make([]int64, 0, len(apiKeyIDs)) + for _, id := range apiKeyIDs { + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + key, ok := r.byID[id] + if ok && key.UserID == userID { + out = append(out, id) + } + } + return out, nil +} + func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 for _, key := range r.byID { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 642c8e09..575e72b1 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -48,6 +48,10 @@ type UsageLogRepository interface { // Account stats GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) + + // Aggregated stats (optimized) + GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) } // usageCache 用于缓存usage数据 diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 4ab50fb5..e6234382 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -34,6 +34,7 @@ type ApiKeyRepository interface { Delete(ctx context.Context, id int64) error ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) + VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error) ExistsByKey(ctx context.Context, key string) (bool, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) @@ -256,6 +257,18 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio return keys, pagination, nil } +func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + if len(apiKeyIDs) == 0 { + return []int64{}, nil + } + + validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs) + if err != nil { + return nil, fmt.Errorf("verify api key ownership: %w", err) + } + return validIDs, nil +} + // GetByID 根据ID获取API Key func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 2ccad4ff..0df8a0de 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -148,22 +148,40 @@ func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, param // GetStatsByUser 获取用户的使用统计 func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) { - logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime) + stats, err := s.usageRepo.GetUserStatsAggregated(ctx, userID, startTime, endTime) if err != nil { - return nil, fmt.Errorf("list usage logs: %w", err) + return nil, fmt.Errorf("get user stats: %w", err) } - return s.calculateStats(logs), nil + return &UsageStats{ + TotalRequests: stats.TotalRequests, + TotalInputTokens: stats.TotalInputTokens, + TotalOutputTokens: stats.TotalOutputTokens, + TotalCacheTokens: stats.TotalCacheTokens, + TotalTokens: stats.TotalTokens, + TotalCost: stats.TotalCost, + TotalActualCost: stats.TotalActualCost, + AverageDurationMs: stats.AverageDurationMs, + }, nil } // GetStatsByApiKey 获取API Key的使用统计 func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) { - logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime) + stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime) if err != nil { - return nil, fmt.Errorf("list usage logs: %w", err) + return nil, fmt.Errorf("get api key stats: %w", err) } - return s.calculateStats(logs), nil + return &UsageStats{ + TotalRequests: stats.TotalRequests, + TotalInputTokens: stats.TotalInputTokens, + TotalOutputTokens: stats.TotalOutputTokens, + TotalCacheTokens: stats.TotalCacheTokens, + TotalTokens: stats.TotalTokens, + TotalCost: stats.TotalCost, + TotalActualCost: stats.TotalActualCost, + AverageDurationMs: stats.AverageDurationMs, + }, nil } // GetStatsByAccount 获取账号的使用统计