fix(test): 修复测试和添加数据库迁移
测试修复: - 修复集成测试中的重复键冲突问题 - 移除 JSON 中多余的尾随逗号 - 新增 inprocess_transport_test.go - 更新 haiku 模型映射测试用例 数据库迁移: - 026: 运营指标聚合表 - 027: 使用量与计费一致性约束
This commit is contained in:
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -70,6 +71,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
|
||||
query := `
|
||||
@@ -107,6 +111,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
@@ -115,11 +120,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
requestIDArg = requestID
|
||||
}
|
||||
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.ApiKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
@@ -143,7 +153,14 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return nil
|
||||
@@ -183,7 +200,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||||
}
|
||||
|
||||
@@ -270,8 +287,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,8 +435,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
@@ -623,7 +640,7 @@ func resolveUsageStatsTimezone() string {
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
@@ -709,11 +726,11 @@ type ModelStat = usagestats.ModelStat
|
||||
// UserUsageTrendPoint represents user usage trend data point
|
||||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -755,10 +772,10 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]APIKeyUsageTrendPoint, 0)
|
||||
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row APIKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
var row ApiKeyUsageTrendPoint
|
||||
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
@@ -844,7 +861,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.TotalApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -853,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveAPIKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1023,9 +1040,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||||
args = append(args, filters.UserID)
|
||||
}
|
||||
if filters.APIKeyID > 0 {
|
||||
if filters.ApiKeyID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||||
args = append(args, filters.APIKeyID)
|
||||
args = append(args, filters.ApiKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||||
@@ -1145,18 +1162,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchApiKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
@@ -1582,7 +1599,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||||
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1603,8 +1620,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
|
||||
if user, ok := users[logs[i].UserID]; ok {
|
||||
logs[i].User = user
|
||||
}
|
||||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||||
logs[i].APIKey = key
|
||||
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
|
||||
logs[i].ApiKey = key
|
||||
}
|
||||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||||
logs[i].Account = acc
|
||||
@@ -1642,7 +1659,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||||
|
||||
for i := range logs {
|
||||
userIDs[logs[i].UserID] = struct{}{}
|
||||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||||
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
|
||||
accountIDs[logs[i].AccountID] = struct{}{}
|
||||
if logs[i].GroupID != nil {
|
||||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||||
@@ -1676,12 +1693,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||||
out := make(map[int64]*service.APIKey)
|
||||
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
|
||||
out := make(map[int64]*service.ApiKey)
|
||||
if len(ids) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1800,7 +1817,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
log := &service.UsageLog{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
ApiKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
|
||||
Reference in New Issue
Block a user