fix(test): 修复测试和添加数据库迁移

测试修复:
- 修复集成测试中的重复键冲突问题
- 移除 JSON 中多余的尾随逗号
- 新增 inprocess_transport_test.go
- 更新 haiku 模型映射测试用例

数据库迁移:
- 026: 运营指标聚合表
- 027: 使用量与计费一致性约束
This commit is contained in:
ianshaw
2026-01-03 06:36:35 -08:00
parent ff3f514f6b
commit b1702de522
16 changed files with 495 additions and 244 deletions

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"strings"
@@ -70,6 +71,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
query := `
@@ -107,6 +111,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
@@ -115,11 +120,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{
log.UserID,
log.APIKeyID,
log.ApiKeyID,
log.AccountID,
log.RequestID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
@@ -143,7 +153,14 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
return err
}
} else {
return err
}
}
log.RateMultiplier = rateMultiplier
return nil
@@ -183,7 +200,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
@@ -270,8 +287,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -418,8 +435,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
@@ -623,7 +640,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
@@ -709,11 +726,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +772,10 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]APIKeyUsageTrendPoint, 0)
results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
@@ -844,7 +861,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalAPIKeys,
&stats.TotalApiKeys,
); err != nil {
return nil, err
}
@@ -853,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveAPIKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -1023,9 +1040,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
if filters.ApiKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
args = append(args, filters.ApiKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1162,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
}
query := `
@@ -1582,7 +1599,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
@@ -1603,8 +1620,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
@@ -1642,7 +1659,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1693,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
@@ -1800,7 +1817,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
APIKeyID: apiKeyID,
ApiKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,