diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ca454606..e4da825b 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,50 +28,64 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +// usageLogInsertArgTypes must stay in the same order as: +// 1. prepareUsageLogInsert().args +// 2. every INSERT/CTE VALUES column list in this file +// 3. execUsageLogInsertNoResult placeholder positions +// 4. scanUsageLog selected column order (via usageLogSelectColumns) +// +// When adding a usage_logs column, update all of those call sites together. var usageLogInsertArgTypes = [...]string{ - "bigint", - "bigint", - "bigint", - "text", - "text", - "text", - "bigint", - "bigint", - "integer", - "integer", - "integer", - "integer", - "integer", - "integer", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "smallint", - "smallint", - "boolean", - "boolean", - "integer", - "integer", - "text", - "text", - "integer", - "text", - "text", - "text", - "text", - "text", - "text", - "boolean", - "timestamptz", + "bigint", // user_id + "bigint", // api_key_id + "bigint", // account_id + "text", // request_id + "text", // model + "text", // requested_model + "text", // upstream_model + "bigint", // group_id + "bigint", // subscription_id + "integer", // input_tokens + "integer", // output_tokens + "integer", // cache_creation_tokens + "integer", // cache_read_tokens + "integer", // cache_creation_5m_tokens + "integer", // cache_creation_1h_tokens + "numeric", // input_cost + "numeric", // output_cost + "numeric", // cache_creation_cost + "numeric", // cache_read_cost + "numeric", // total_cost + "numeric", // actual_cost + "numeric", // rate_multiplier + "numeric", // account_rate_multiplier + "smallint", // billing_type + "smallint", // request_type + "boolean", // stream + "boolean", // openai_ws_mode + "integer", // duration_ms + "integer", // first_token_ms + "text", // user_agent + "text", // ip_address + "integer", // image_count + "text", // image_size + "text", // media_type + "text", // service_tier + "text", // reasoning_effort + "text", // inbound_endpoint + "text", // upstream_endpoint + "boolean", // cache_ttl_overridden + "timestamptz", // created_at } +const rawUsageLogModelColumn = "model" + +// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters. +// Historical rows may contain upstream/billing model values, while newer rows store requested_model. +// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead. + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string { return "YYYY-MM-DD" } +// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. +func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) { + if strings.TrimSpace(model) == "" { + return conditions, args + } + conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1)) + args = append(args, model) + return conditions, args +} + +// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. +func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) { + if strings.TrimSpace(model) == "" { + return query, args + } + query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1) + args = append(args, model) + return query, args +} + type usageLogRepository struct { client *dbent.Client sql sqlExecutor @@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, $6, - $7, $8, - $9, $10, $11, $12, - $13, $14, - $15, $16, $17, $18, $19, $20, - $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*39) + args := make([]any, 0, len(preparedList)*40) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, $6, - $7, $8, - $9, $10, $11, $12, - $13, $14, - $15, $16, $17, $18, $19, $20, - $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + requestedModel := strings.TrimSpace(log.RequestedModel) + if requestedModel == "" { + requestedModel = strings.TrimSpace(log.Model) + } upstreamModel := nullString(log.UpstreamModel) var requestIDArg any @@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + nullString(&requestedModel), upstreamModel, groupID, subscriptionID, @@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco // GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 // 性能优化:数据库层聚合计算,避免应用层循环统计 func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := ` + query := fmt.Sprintf(` SELECT COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms FROM usage_logs - WHERE model = $1 AND created_at >= $2 AND created_at < $3 - ` + WHERE %s = $1 AND created_at >= $2 AND created_at < $3 + `, rawUsageLogModelColumn) var stats usagestats.UsageStats if err := scanSingleRow( @@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" + query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS // resolveModelDimensionExpression maps model source type to a safe SQL expression. func resolveModelDimensionExpression(modelType string) string { + requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)" switch usagestats.NormalizeModelSource(modelType) { case usagestats.ModelSourceUpstream: - return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr) case usagestats.ModelSourceMapping: - return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr) default: - return "model" + return requestedExpr } } @@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + requestedModel sql.NullString upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 @@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &requestedModel, &upstreamModel, &groupID, &subscriptionID, @@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e APIKeyID: apiKeyID, AccountID: accountID, Model: model, + RequestedModel: coalesceTrimmedString(requestedModel, model), InputTokens: inputTokens, OutputTokens: outputTokens, CacheCreationTokens: cacheCreationTokens, @@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString { return sql.NullString{String: *v, Valid: true} } +func coalesceTrimmedString(v sql.NullString, fallback string) string { + if v.Valid && strings.TrimSpace(v.String) != "" { + return v.String + } + return fallback +} + func setToSlice(set map[int64]struct{}) []int64 { out := make([]int64, 0, len(set)) for id := range set {