feat(backend): 增强使用统计和API密钥功能
- 优化使用统计处理逻辑 - 增强API密钥仓储层功能 - 改进账户使用服务 - 完善API契约测试覆盖
This commit is contained in:
@@ -371,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify ownership of all requested API keys
|
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
if len(req.ApiKeyIDs) > 100 {
|
||||||
if err != nil {
|
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userApiKeyIDs := make(map[int64]bool)
|
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||||
for _, key := range userApiKeys {
|
if err != nil {
|
||||||
userApiKeyIDs[key.ID] = true
|
response.ErrorFrom(c, err)
|
||||||
}
|
return
|
||||||
|
|
||||||
// Filter to only include user's own API keys
|
|
||||||
validApiKeyIDs := make([]int64, 0)
|
|
||||||
for _, id := range req.ApiKeyIDs {
|
|
||||||
if userApiKeyIDs[id] {
|
|
||||||
validApiKeyIDs = append(validApiKeyIDs, id)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(validApiKeyIDs) == 0 {
|
if len(validApiKeyIDs) == 0 {
|
||||||
|
|||||||
@@ -81,6 +81,22 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
|||||||
return outKeys, paginationResultFromTotal(total, params), nil
|
return outKeys, paginationResultFromTotal(total, params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
if len(apiKeyIDs) == 0 {
|
||||||
|
return []int64{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]int64, 0, len(apiKeyIDs))
|
||||||
|
err := r.db.WithContext(ctx).
|
||||||
|
Model(&apiKeyModel{}).
|
||||||
|
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||||
|
Pluck("id", &ids).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||||
|
|||||||
@@ -788,6 +788,25 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
if len(apiKeyIDs) == 0 {
|
||||||
|
return []int64{}, nil
|
||||||
|
}
|
||||||
|
seen := make(map[int64]struct{}, len(apiKeyIDs))
|
||||||
|
out := make([]int64, 0, len(apiKeyIDs))
|
||||||
|
for _, id := range apiKeyIDs {
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
key, ok := r.byID[id]
|
||||||
|
if ok && key.UserID == userID {
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
for _, key := range r.byID {
|
for _, key := range r.byID {
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ type UsageLogRepository interface {
|
|||||||
|
|
||||||
// Account stats
|
// Account stats
|
||||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||||
|
|
||||||
|
// Aggregated stats (optimized)
|
||||||
|
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
|
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// usageCache 用于缓存usage数据
|
// usageCache 用于缓存usage数据
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ type ApiKeyRepository interface {
|
|||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||||
|
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||||
@@ -256,6 +257,18 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
|||||||
return keys, pagination, nil
|
return keys, pagination, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
if len(apiKeyIDs) == 0 {
|
||||||
|
return []int64{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||||||
|
}
|
||||||
|
return validIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取API Key
|
// GetByID 根据ID获取API Key
|
||||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||||
|
|||||||
@@ -148,22 +148,40 @@ func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, param
|
|||||||
|
|
||||||
// GetStatsByUser 获取用户的使用统计
|
// GetStatsByUser 获取用户的使用统计
|
||||||
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
stats, err := s.usageRepo.GetUserStatsAggregated(ctx, userID, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
return nil, fmt.Errorf("get user stats: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.calculateStats(logs), nil
|
return &UsageStats{
|
||||||
|
TotalRequests: stats.TotalRequests,
|
||||||
|
TotalInputTokens: stats.TotalInputTokens,
|
||||||
|
TotalOutputTokens: stats.TotalOutputTokens,
|
||||||
|
TotalCacheTokens: stats.TotalCacheTokens,
|
||||||
|
TotalTokens: stats.TotalTokens,
|
||||||
|
TotalCost: stats.TotalCost,
|
||||||
|
TotalActualCost: stats.TotalActualCost,
|
||||||
|
AverageDurationMs: stats.AverageDurationMs,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatsByApiKey 获取API Key的使用统计
|
// GetStatsByApiKey 获取API Key的使用统计
|
||||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||||
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
|
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
return nil, fmt.Errorf("get api key stats: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.calculateStats(logs), nil
|
return &UsageStats{
|
||||||
|
TotalRequests: stats.TotalRequests,
|
||||||
|
TotalInputTokens: stats.TotalInputTokens,
|
||||||
|
TotalOutputTokens: stats.TotalOutputTokens,
|
||||||
|
TotalCacheTokens: stats.TotalCacheTokens,
|
||||||
|
TotalTokens: stats.TotalTokens,
|
||||||
|
TotalCost: stats.TotalCost,
|
||||||
|
TotalActualCost: stats.TotalActualCost,
|
||||||
|
AverageDurationMs: stats.AverageDurationMs,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatsByAccount 获取账号的使用统计
|
// GetStatsByAccount 获取账号的使用统计
|
||||||
|
|||||||
Reference in New Issue
Block a user