feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
query := `
INSERT INTO usage_logs (
@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
requestType,
log.Stream,
log.OpenAIWSMode,
duration,
firstToken,
userAgent,
@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
}
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
totalStatsQuery := `
todayEnd := todayUTC.Add(24 * time.Hour)
combinedStatsQuery := `
WITH scoped AS (
SELECT
created_at,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
FROM scoped
`
var totalDurationMs int64
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{startUTC, endUTC},
combinedStatsQuery,
[]any{startUTC, endUTC, todayUTC, todayEnd},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCost,
&stats.TotalActualCost,
&totalDurationMs,
); err != nil {
return err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
todayEnd := todayUTC.Add(24 * time.Hour)
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{todayUTC, todayEnd},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
); err != nil {
return err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
activeUsersQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
return err
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
hourStart := now.UTC().Truncate(time.Hour)
hourEnd := hourStart.Add(time.Hour)
hourlyActiveQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
activeUsersQuery := `
WITH scoped AS (
SELECT user_id, created_at
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT
COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
FROM scoped
`
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
return err
}
@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
return result, nil
}
// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
// 模型分类规则与 service.geminiModelClassFromName 一致model 包含 flash/lite 视为 flash其余视为 pro。
func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
if len(accountIDs) == 0 {
return result, nil
}
query := `
SELECT
account_id,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY account_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var accountID int64
var totals service.GeminiUsageTotals
if err := rows.Scan(
&accountID,
&totals.FlashRequests,
&totals.ProRequests,
&totals.FlashTokens,
&totals.ProTokens,
&totals.FlashCost,
&totals.ProCost,
); err != nil {
return nil, err
}
result[accountID] = totals
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, accountID := range accountIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = service.GeminiUsageTotals{}
}
}
return result, nil
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint
@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
if err != nil {
models = []ModelStat{}
}
@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
requestTypeRaw int16
stream bool
openaiWSMode bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&requestTypeRaw,
&stream,
&openaiWSMode,
&durationMs,
&firstTokenMs,
&userAgent,
@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
RequestType: service.RequestTypeFromInt16(requestTypeRaw),
ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt,
}
// 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
log.Stream = stream
log.OpenAIWSMode = openaiWSMode
log.RequestType = log.EffectiveRequestType()
log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
if requestID.Valid {
log.RequestID = requestID.String
@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string {
return "WHERE " + strings.Join(conditions, " AND ")
}
func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
if requestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
return conditions, args
}
if stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *stream)
}
return conditions, args
}
func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
if requestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
query += " AND " + condition
args = append(args, conditionArgs...)
return query, args
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
return query, args
}
// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。
func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
normalized := service.RequestTypeFromInt16(requestType)
requestTypeArg := int16(normalized)
switch normalized {
case service.RequestTypeSync:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
case service.RequestTypeStream:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
case service.RequestTypeWSV2:
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
default:
return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg}
}
}
func nullInt64(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}