feat(sync): full code sync from release
This commit is contained in:
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
log.SyncRequestTypeAndLegacyFields()
|
||||
requestType := int16(log.RequestType)
|
||||
|
||||
query := `
|
||||
INSERT INTO usage_logs (
|
||||
@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
rateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
requestType,
|
||||
log.Stream,
|
||||
log.OpenAIWSMode,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||
totalStatsQuery := `
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||
combinedStatsQuery := `
|
||||
WITH scoped AS (
|
||||
SELECT
|
||||
created_at,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
COALESCE(duration_ms, 0) AS duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
||||
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
|
||||
)
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
|
||||
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
|
||||
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
|
||||
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
|
||||
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
|
||||
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
|
||||
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
|
||||
FROM scoped
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{startUTC, endUTC},
|
||||
combinedStatsQuery,
|
||||
[]any{startUTC, endUTC, todayUTC, todayEnd},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{todayUTC, todayEnd},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
activeUsersQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||
return err
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
hourStart := now.UTC().Truncate(time.Hour)
|
||||
hourEnd := hourStart.Add(time.Hour)
|
||||
hourlyActiveQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
activeUsersQuery := `
|
||||
WITH scoped AS (
|
||||
SELECT user_id, created_at
|
||||
FROM usage_logs
|
||||
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
||||
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
|
||||
)
|
||||
SELECT
|
||||
COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
|
||||
COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
|
||||
FROM scoped
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
|
||||
// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。
|
||||
func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
|
||||
result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
|
||||
if len(accountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
account_id,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY account_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
var totals service.GeminiUsageTotals
|
||||
if err := rows.Scan(
|
||||
&accountID,
|
||||
&totals.FlashRequests,
|
||||
&totals.ProRequests,
|
||||
&totals.FlashTokens,
|
||||
&totals.ProTokens,
|
||||
&totals.FlashCost,
|
||||
&totals.ProCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[accountID] = totals
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
if _, ok := result[accountID]; !ok {
|
||||
result[accountID] = service.GeminiUsageTotals{}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint = usagestats.TrendDataPoint
|
||||
|
||||
@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||||
args = append(args, *filters.Stream)
|
||||
}
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
if stream != nil {
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if stream != nil {
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||||
args = append(args, *filters.Stream)
|
||||
}
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
rateMultiplier float64
|
||||
accountRateMultiplier sql.NullFloat64
|
||||
billingType int16
|
||||
requestTypeRaw int16
|
||||
stream bool
|
||||
openaiWSMode bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&rateMultiplier,
|
||||
&accountRateMultiplier,
|
||||
&billingType,
|
||||
&requestTypeRaw,
|
||||
&stream,
|
||||
&openaiWSMode,
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
RateMultiplier: rateMultiplier,
|
||||
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||||
BillingType: int8(billingType),
|
||||
Stream: stream,
|
||||
RequestType: service.RequestTypeFromInt16(requestTypeRaw),
|
||||
ImageCount: imageCount,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
// 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
|
||||
log.Stream = stream
|
||||
log.OpenAIWSMode = openaiWSMode
|
||||
log.RequestType = log.EffectiveRequestType()
|
||||
log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
|
||||
|
||||
if requestID.Valid {
|
||||
log.RequestID = requestID.String
|
||||
@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string {
|
||||
return "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
|
||||
if requestType != nil {
|
||||
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, conditionArgs...)
|
||||
return conditions, args
|
||||
}
|
||||
if stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||||
args = append(args, *stream)
|
||||
}
|
||||
return conditions, args
|
||||
}
|
||||
|
||||
func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
|
||||
if requestType != nil {
|
||||
condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
|
||||
query += " AND " + condition
|
||||
args = append(args, conditionArgs...)
|
||||
return query, args
|
||||
}
|
||||
if stream != nil {
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
return query, args
|
||||
}
|
||||
|
||||
// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。
|
||||
func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
|
||||
normalized := service.RequestTypeFromInt16(requestType)
|
||||
requestTypeArg := int16(normalized)
|
||||
switch normalized {
|
||||
case service.RequestTypeSync:
|
||||
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
|
||||
case service.RequestTypeStream:
|
||||
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
|
||||
case service.RequestTypeWSV2:
|
||||
return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
|
||||
default:
|
||||
return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg}
|
||||
}
|
||||
}
|
||||
|
||||
func nullInt64(v *int64) sql.NullInt64 {
|
||||
if v == nil {
|
||||
return sql.NullInt64{}
|
||||
|
||||
Reference in New Issue
Block a user