运维监控系统安全加固和功能优化 (#21)

* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
This commit is contained in:
IanShaw
2026-01-02 20:01:12 +08:00
committed by GitHub
parent 7fdc2b2d29
commit 45bd9ac705
171 changed files with 10618 additions and 2965 deletions

View File

@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
},
accType: service.AccountTypeApiKey,
accType: service.AccountTypeAPIKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
},
},
{

View File

@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewApiKeyRepository(entClient)
apiKeyRepo := NewAPIKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.ApiKey{
key := &service.APIKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",

View File

@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
return &apiKeyCache{rdb: rdb}
}

View File

@@ -13,11 +13,11 @@ import (
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
type APIKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -78,7 +78,7 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
func (s *APIKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -122,6 +122,6 @@ func (s *ApiKeyCacheSuite) TestDailyUsage() {
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
func TestAPIKeyCacheSuite(t *testing.T) {
suite.Run(t, new(APIKeyCacheSuite))
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
func TestAPIKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64

View File

@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client
}
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) erro
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 不加载完整的 APIKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrApiKeyNotFound
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrApiKeyNotFound
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.ApiKey.Update().
builder := r.client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.ApiKey.Update().
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.ApiKey.Query().
exists, err := r.client.APIKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists {
return nil
}
return service.ErrApiKeyNotFound
return service.ErrAPIKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids, err := r.client.ApiKey.Query().
ids, err := r.client.APIKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
outKeys := make([]service.ApiKey, 0, len(keys))
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.ApiKey.Update().
n, err := r.client.APIKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.ApiKey{
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,

View File

@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite"
)
type ApiKeyRepoSuite struct {
type APIKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
func (s *APIKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
@@ -51,16 +51,16 @@ func (s *ApiKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
func (s *APIKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
@@ -78,16 +78,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
func (s *APIKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
@@ -108,10 +108,10 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
@@ -131,9 +131,9 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
func (s *APIKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.ApiKey{
key := &service.APIKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -150,10 +150,10 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
func (s *APIKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -161,10 +161,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
@@ -174,10 +174,10 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
func (s *APIKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -186,13 +186,13 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
func (s *APIKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -202,10 +202,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
func (s *APIKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -214,9 +214,9 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
func (s *APIKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -227,47 +227,47 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists)
}
// --- SearchApiKeys ---
// --- SearchAPIKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -284,10 +284,10 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -320,13 +320,13 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
@@ -346,7 +346,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
@@ -359,7 +359,7 @@ func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u)
}
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
@@ -370,10 +370,10 @@ func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g)
}
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
s.T().Helper()
k := &service.ApiKey{
k := &service.APIKey{
UserID: userID,
Key: key,
Name: name,

View File

@@ -27,8 +27,14 @@ const (
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Wait queue keys (global structures)
// - total: integer total queue depth across all users
// - updated: sorted set of userID -> lastUpdateUnixSec (for TTL cleanup)
// - counts: hash of userID -> current wait count
waitQueueTotalKey = "concurrency:wait:total"
waitQueueUpdatedKey = "concurrency:wait:updated"
waitQueueCountsKey = "concurrency:wait:counts"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
@@ -94,27 +100,55 @@ var (
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = userID
// ARGV[2] = maxWait
// ARGV[3] = TTL in seconds
// ARGV[4] = cleanup limit
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local maxWait = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local cleanupLimit = tonumber(ARGV[4])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
if current >= tonumber(ARGV[1]) then
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current >= maxWait then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
local newVal = current + 1
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
redis.call('INCR', totalKey)
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
-- Keep global structures from living forever in totally idle deployments.
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
@@ -144,6 +178,111 @@ var (
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local userID = ARGV[1]
local ttl = tonumber(ARGV[2])
local cleanupLimit = tonumber(ARGV[3])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
local current = tonumber(redis.call('HGET', countsKey, userID) or '0')
if current <= 0 then
return 1
end
local newVal = current - 1
if newVal <= 0 then
redis.call('HDEL', countsKey, userID)
redis.call('ZREM', updatedKey, userID)
else
redis.call('HSET', countsKey, userID, newVal)
redis.call('ZADD', updatedKey, now, userID)
end
redis.call('DECR', totalKey)
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
return 1
`)
// getTotalWaitScript returns the global wait depth with TTL cleanup.
// KEYS[1] = total key
// KEYS[2] = updated zset key
// KEYS[3] = counts hash key
// ARGV[1] = TTL in seconds
// ARGV[2] = cleanup limit
getTotalWaitScript = redis.NewScript(`
local totalKey = KEYS[1]
local updatedKey = KEYS[2]
local countsKey = KEYS[3]
local ttl = tonumber(ARGV[1])
local cleanupLimit = tonumber(ARGV[2])
redis.call('SETNX', totalKey, 0)
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Cleanup expired users (bounded)
local expired = redis.call('ZRANGEBYSCORE', updatedKey, '-inf', expireBefore, 'LIMIT', 0, cleanupLimit)
for _, uid in ipairs(expired) do
local c = tonumber(redis.call('HGET', countsKey, uid) or '0')
if c > 0 then
redis.call('DECRBY', totalKey, c)
end
redis.call('HDEL', countsKey, uid)
redis.call('ZREM', updatedKey, uid)
end
-- If totalKey got lost but counts exist (e.g. Redis restart), recompute once.
local total = redis.call('GET', totalKey)
if total == false then
total = 0
local vals = redis.call('HVALS', countsKey)
for _, v in ipairs(vals) do
total = total + tonumber(v)
end
redis.call('SET', totalKey, total)
end
local ttlKeep = ttl * 2
redis.call('EXPIRE', totalKey, ttlKeep)
redis.call('EXPIRE', updatedKey, ttlKeep)
redis.call('EXPIRE', countsKey, ttlKeep)
local result = tonumber(redis.call('GET', totalKey) or '0')
if result < 0 then
result = 0
redis.call('SET', totalKey, 0)
end
return result
`)
// decrementAccountWaitScript - account-level wait queue decrement
decrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
@@ -244,7 +383,9 @@ func userSlotKey(userID int64) string {
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Historical: per-user string keys were used.
// Now we use global structures keyed by userID string.
return strconv.FormatInt(userID, 10)
}
func accountWaitKey(accountID int64) string {
@@ -308,8 +449,16 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
userKey := waitQueueKey(userID)
result, err := incrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
maxWait,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Int()
if err != nil {
return false, err
}
@@ -317,11 +466,35 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
userKey := waitQueueKey(userID)
_, err := decrementWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
userKey,
c.waitQueueTTLSeconds,
200, // cleanup limit per call
).Result()
return err
}
func (c *concurrencyCache) GetTotalWaitCount(ctx context.Context) (int, error) {
if c.rdb == nil {
return 0, nil
}
total, err := getTotalWaitScript.Run(
ctx,
c.rdb,
[]string{waitQueueTotalKey, waitQueueUpdatedKey, waitQueueCountsKey},
c.waitQueueTTLSeconds,
500, // cleanup limit per query (rare)
).Int64()
if err != nil {
return 0, err
}
return int(total), nil
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
@@ -335,7 +508,7 @@ func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accoun
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
_, err := decrementAccountWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -158,7 +158,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
userKey := waitQueueKey(userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
@@ -172,31 +172,31 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
ttl, err := s.rdb.TTL(s.ctx, waitQueueTotalKey).Result()
require.NoError(s.T(), err, "TTL wait total key")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL*2)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.NoError(s.T(), err, "HGET wait queue count")
require.Equal(s.T(), 1, val, "expected wait count 1")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
require.NoError(s.T(), err, "GET wait queue total")
require.Equal(s.T(), 1, total, "expected total wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
userKey := waitQueueKey(userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
// Verify count remains zero / absent.
val, err := s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Int()
require.True(s.T(), errors.Is(err, redis.Nil))
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
@@ -210,12 +210,15 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
// Verify per-user count is absent and total is non-negative.
_, err = s.rdb.HGet(s.ctx, waitQueueCountsKey, userKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil), "expected count field removed on zero")
total, err := s.rdb.Get(s.ctx, waitQueueTotalKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
require.NoError(s.T(), err)
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
require.GreaterOrEqual(s.T(), total, 0, "expected non-negative total wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {

View File

@@ -1,4 +1,4 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// Package repository 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository

View File

@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a
}
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
t.Helper()
ctx := context.Background()
@@ -257,7 +257,7 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *se
k.Name = "default"
}
create := client.ApiKey.Create().
create := client.APIKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).

View File

@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update().
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {

View File

@@ -0,0 +1,190 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// ListErrorLogs queries ops_error_logs with optional filters and pagination.
// It returns the list items and the total count of matching rows.
func (r *OpsRepository) ListErrorLogs(ctx context.Context, filter *service.ErrorLogFilter) ([]*service.ErrorLog, int64, error) {
page := 1
pageSize := 20
if filter != nil {
if filter.Page > 0 {
page = filter.Page
}
if filter.PageSize > 0 {
pageSize = filter.PageSize
}
}
if pageSize > 100 {
pageSize = 100
}
offset := (page - 1) * pageSize
conditions := make([]string, 0)
args := make([]any, 0)
addCondition := func(condition string, values ...any) {
conditions = append(conditions, condition)
args = append(args, values...)
}
if filter != nil {
// 默认查询最近 24 小时
if filter.StartTime == nil && filter.EndTime == nil {
defaultStart := time.Now().Add(-24 * time.Hour)
filter.StartTime = &defaultStart
}
if filter.StartTime != nil {
addCondition(fmt.Sprintf("created_at >= $%d", len(args)+1), *filter.StartTime)
}
if filter.EndTime != nil {
addCondition(fmt.Sprintf("created_at <= $%d", len(args)+1), *filter.EndTime)
}
if filter.ErrorCode != nil {
addCondition(fmt.Sprintf("status_code = $%d", len(args)+1), *filter.ErrorCode)
}
if provider := strings.TrimSpace(filter.Provider); provider != "" {
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), provider)
}
if filter.AccountID != nil {
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
}
}
where := ""
if len(conditions) > 0 {
where = "WHERE " + strings.Join(conditions, " AND ")
}
countQuery := fmt.Sprintf(`SELECT COUNT(1) FROM ops_error_logs %s`, where)
var total int64
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
if err == sql.ErrNoRows {
total = 0
} else {
return nil, 0, err
}
}
listQuery := fmt.Sprintf(`
SELECT
id,
created_at,
severity,
request_id,
account_id,
request_path,
platform,
model,
status_code,
error_message,
duration_ms,
retry_count,
stream
FROM ops_error_logs
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, where, len(args)+1, len(args)+2)
listArgs := append(append([]any{}, args...), pageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, listArgs...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
results := make([]*service.ErrorLog, 0)
for rows.Next() {
var (
id int64
createdAt time.Time
severity sql.NullString
requestID sql.NullString
accountID sql.NullInt64
requestURI sql.NullString
platform sql.NullString
model sql.NullString
statusCode sql.NullInt64
message sql.NullString
durationMs sql.NullInt64
retryCount sql.NullInt64
stream sql.NullBool
)
if err := rows.Scan(
&id,
&createdAt,
&severity,
&requestID,
&accountID,
&requestURI,
&platform,
&model,
&statusCode,
&message,
&durationMs,
&retryCount,
&stream,
); err != nil {
return nil, 0, err
}
entry := &service.ErrorLog{
ID: id,
Timestamp: createdAt,
Level: levelFromSeverity(severity.String),
RequestID: requestID.String,
APIPath: requestURI.String,
Provider: platform.String,
Model: model.String,
HTTPCode: int(statusCode.Int64),
Stream: stream.Bool,
}
if accountID.Valid {
entry.AccountID = strconv.FormatInt(accountID.Int64, 10)
}
if message.Valid {
entry.ErrorMessage = message.String
}
if durationMs.Valid {
v := int(durationMs.Int64)
entry.DurationMs = &v
}
if retryCount.Valid {
v := int(retryCount.Int64)
entry.RetryCount = &v
}
results = append(results, entry)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return results, total, nil
}
func levelFromSeverity(severity string) string {
sev := strings.ToUpper(strings.TrimSpace(severity))
switch sev {
case "P0", "P1":
return "CRITICAL"
case "P2":
return "ERROR"
case "P3":
return "WARN"
default:
return "ERROR"
}
}

View File

@@ -0,0 +1,127 @@
package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
opsLatestMetricsKey = "ops:metrics:latest"
opsDashboardOverviewKeyPrefix = "ops:dashboard:overview:"
opsLatestMetricsTTL = 10 * time.Second
)
func (r *OpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*service.OpsMetrics, error) {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil, nil
}
data, err := r.rdb.Get(ctx, opsLatestMetricsKey).Bytes()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("redis get cached latest system metric: %w", err)
}
var metric service.OpsMetrics
if err := json.Unmarshal(data, &metric); err != nil {
return nil, fmt.Errorf("unmarshal cached latest system metric: %w", err)
}
return &metric, nil
}
func (r *OpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *service.OpsMetrics) error {
if metric == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil
}
data, err := json.Marshal(metric)
if err != nil {
return fmt.Errorf("marshal cached latest system metric: %w", err)
}
return r.rdb.Set(ctx, opsLatestMetricsKey, data, opsLatestMetricsTTL).Err()
}
func (r *OpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*service.DashboardOverviewData, error) {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil, nil
}
rangeKey := strings.TrimSpace(timeRange)
if rangeKey == "" {
rangeKey = "1h"
}
key := opsDashboardOverviewKeyPrefix + rangeKey
data, err := r.rdb.Get(ctx, key).Bytes()
if errors.Is(err, redis.Nil) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("redis get cached dashboard overview: %w", err)
}
var overview service.DashboardOverviewData
if err := json.Unmarshal(data, &overview); err != nil {
return nil, fmt.Errorf("unmarshal cached dashboard overview: %w", err)
}
return &overview, nil
}
func (r *OpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *service.DashboardOverviewData, ttl time.Duration) error {
if data == nil {
return nil
}
if ttl <= 0 {
ttl = 10 * time.Second
}
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return nil
}
rangeKey := strings.TrimSpace(timeRange)
if rangeKey == "" {
rangeKey = "1h"
}
payload, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal cached dashboard overview: %w", err)
}
key := opsDashboardOverviewKeyPrefix + rangeKey
return r.rdb.Set(ctx, key, payload, ttl).Err()
}
func (r *OpsRepository) PingRedis(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
if r == nil || r.rdb == nil {
return errors.New("redis client is nil")
}
return r.rdb.Ping(ctx).Err()
}

File diff suppressed because it is too large Load Diff

View File

@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
return u
}
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
@@ -53,28 +53,28 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.ApiKey.Query().
got, err := client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
@@ -86,15 +86,15 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
@@ -105,10 +105,10 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.ApiKey.Query().
_, err = client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")

View File

@@ -117,7 +117,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
args := []any{
log.UserID,
log.ApiKeyID,
log.APIKeyID,
log.AccountID,
log.RequestID,
log.Model,
@@ -183,7 +183,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
@@ -270,8 +270,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -418,8 +418,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
@@ -623,7 +623,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
@@ -709,11 +709,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +755,10 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]ApiKeyUsageTrendPoint, 0)
results = make([]APIKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
@@ -844,7 +844,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
&stats.TotalAPIKeys,
); err != nil {
return nil, err
}
@@ -853,7 +853,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
}
@@ -1023,9 +1023,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.ApiKeyID > 0 {
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID)
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1145,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
@@ -1582,7 +1582,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
@@ -1603,8 +1603,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
@@ -1642,7 +1642,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1676,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
@@ -1800,7 +1800,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
ApiKeyID: apiKeyID,
APIKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,

View File

@@ -35,10 +35,10 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: inputTokens,
@@ -55,12 +55,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
@@ -76,7 +76,7 @@ func (s *UsageLogRepoSuite) TestCreate() {
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -96,7 +96,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -112,7 +112,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -124,18 +124,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
// --- ListByAPIKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAPIKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
@@ -144,7 +144,7 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -159,7 +159,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -179,7 +179,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -211,8 +211,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
@@ -223,7 +223,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
@@ -240,7 +240,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
@@ -254,7 +254,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
APIKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
@@ -272,8 +272,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
@@ -300,14 +300,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
@@ -315,7 +315,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -331,8 +331,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
@@ -351,24 +351,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
// --- GetBatchAPIKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -377,7 +377,7 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -402,7 +402,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -417,11 +417,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
// --- ListByAPIKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -431,8 +431,8 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
s.Require().Len(logs, 2)
}
@@ -440,7 +440,7 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -459,7 +459,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -467,7 +467,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
@@ -480,7 +480,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
@@ -493,7 +493,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
@@ -515,7 +515,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
@@ -535,7 +535,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -552,7 +552,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -571,7 +571,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -579,7 +579,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -592,7 +592,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -618,7 +618,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -646,7 +646,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -665,14 +665,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -685,7 +685,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -719,7 +719,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -727,7 +727,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -740,7 +740,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -782,8 +782,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -799,12 +799,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
// --- GetAPIKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -815,14 +815,14 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -832,21 +832,21 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
@@ -855,7 +855,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -874,7 +874,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -885,7 +885,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}

View File

@@ -28,12 +28,13 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewOpsRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
@@ -42,7 +43,7 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewAPIKeyCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,