feat(backend): 增强使用统计和API密钥功能
- 优化使用统计处理逻辑 - 增强API密钥仓储层功能 - 改进账户使用服务 - 完善API契约测试覆盖
This commit is contained in:
@@ -371,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
userApiKeyIDs := make(map[int64]bool)
|
||||
for _, key := range userApiKeys {
|
||||
userApiKeyIDs[key.ID] = true
|
||||
}
|
||||
|
||||
// Filter to only include user's own API keys
|
||||
validApiKeyIDs := make([]int64, 0)
|
||||
for _, id := range req.ApiKeyIDs {
|
||||
if userApiKeyIDs[id] {
|
||||
validApiKeyIDs = append(validApiKeyIDs, id)
|
||||
}
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
|
||||
@@ -81,6 +81,22 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(apiKeyIDs))
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||
Pluck("id", &ids).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
|
||||
@@ -788,6 +788,25 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(apiKeyIDs))
|
||||
out := make([]int64, 0, len(apiKeyIDs))
|
||||
for _, id := range apiKeyIDs {
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
key, ok := r.byID[id]
|
||||
if ok && key.UserID == userID {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
for _, key := range r.byID {
|
||||
|
||||
@@ -48,6 +48,10 @@ type UsageLogRepository interface {
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
}
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
|
||||
@@ -34,6 +34,7 @@ type ApiKeyRepository interface {
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
@@ -256,6 +257,18 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||||
}
|
||||
return validIDs, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -148,22 +148,40 @@ func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, param
|
||||
|
||||
// GetStatsByUser 获取用户的使用统计
|
||||
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetUserStatsAggregated(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get user stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByApiKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get api key stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByAccount 获取账号的使用统计
|
||||
|
||||
Reference in New Issue
Block a user