fix(audit): 第二批审计修复 — P0 生产 Bug、安全加固、性能优化、缓存一致性、代码质量

基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题:

P0 生产 Bug:
- 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03)
- generateRandomID 改用 crypto/rand 替代固定索引 (P0-04)
- IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05)

安全加固:
- gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14)
- usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16)
- 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15)

性能优化:
- gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02)
- group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03)
- usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07)
- GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10)
- ip.go CIDR 预编译为包级变量 (P1-11)
- BatchUpdateCredentials 重构为先验证后更新 (P1-13)

缓存一致性:
- billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10)
- DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12)
- UserService.UpdateBalance 成功后异步失效 billingCache (P2-13)

代码质量:
- search 截断改为按 rune 处理,支持多字节字符 (P2-01)
- TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07)
- CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16)

测试覆盖:
- 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例)
- 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例)
- 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go
- 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试
- config_test.go 补充 server.mode/sslmode 默认值断言

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 19:46:42 +08:00
parent f6ca701917
commit 2588fa6a8f
36 changed files with 1037 additions and 178 deletions

View File

@@ -24,6 +24,22 @@ import (
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
"day": "YYYY-MM-DD",
"week": "IYYY-IW",
"month": "YYYY-MM",
}
// safeDateFormat 根据白名单获取 dateFormat未匹配时返回默认值
func safeDateFormat(granularity string) string {
if f, ok := dateFormatWhitelist[granularity]; ok {
return f
}
return "YYYY-MM-DD"
}
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
return logs, nil, err
}
@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
}
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err
}
@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_keys AS (
@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
WITH top_users AS (
@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT
@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range userIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
if err != nil {
return nil, err
}
@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
// 默认最近 30 天
if startTime.IsZero() {
startTime = time.Now().AddDate(0, 0, -30)
}
if endTime.IsZero() {
endTime = time.Now()
}
for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
if err != nil {
return nil, err
}
@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
SELECT