From 3d7f8e4b3ac96c537014cd5f8af625aa6a8273b9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 10:17:38 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E6=9C=8D=E5=8A=A1):=20=E4=BF=AE=E5=A4=8Dsy?= =?UTF-8?q?stem=E5=88=A4=E5=AE=9A=E3=80=81=E7=BB=9F=E8=AE=A1=E6=97=B6?= =?UTF-8?q?=E5=8C=BA=E4=B8=8E=E7=BC=93=E5=AD=98=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - system 字段存在即视为显式提供,避免 null 触发默认注入 - 日统计分组显式使用应用时区,缺失时从 TZ 回退到 UTC - 缓存写入队列丢弃日志节流汇总,关键任务同步回退 测试: go test ./internal/service -run TestBillingCacheServiceQueueHighLoad --- backend/internal/repository/usage_log_repo.go | 20 ++- .../internal/service/billing_cache_service.go | 120 ++++++++++++++++-- backend/internal/service/gateway_request.go | 6 +- .../internal/service/gateway_request_test.go | 4 +- 4 files changed, 132 insertions(+), 18 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 4e26d751..9a210bde 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "os" "strings" "time" @@ -536,9 +537,11 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN // GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 // 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { + tzName := resolveUsageStatsTimezone() query := ` SELECT - TO_CHAR(created_at, 'YYYY-MM-DD') as date, + -- 使用应用时区分组,避免数据库会话时区导致日边界偏移。 + TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date, COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(output_tokens), 0) as total_output_tokens, @@ -552,7 +555,7 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID ORDER BY 1 ` - rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName) if err != nil { return nil, err } @@ -607,6 +610,19 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID return result, nil } +// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。 +// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。 +func resolveUsageStatsTimezone() string { + tzName := timezone.Name() + if tzName != "" && tzName != "Local" { + return tzName + } + if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" { + return envTZ + } + return "UTC" +} + func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index ac320535..58ed555a 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sync" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -49,12 +50,13 @@ const ( // 新实现使用固定大小的工作池: // 1. 预创建 10 个 worker goroutine,避免频繁创建销毁 // 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值 -// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受) +// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警 // 4. 统一超时控制,避免慢操作阻塞工作池 const ( - cacheWriteWorkerCount = 10 // 工作协程数量 - cacheWriteBufferSize = 1000 // 任务队列缓冲大小 - cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 + cacheWriteWorkerCount = 10 // 工作协程数量 + cacheWriteBufferSize = 1000 // 任务队列缓冲大小 + cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 + cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 ) // cacheWriteTask 缓存写入任务 @@ -78,6 +80,11 @@ type BillingCacheService struct { cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + // 丢弃日志节流计数器(减少高负载下日志噪音) + cacheWriteDropFullCount uint64 + cacheWriteDropFullLastLog int64 + cacheWriteDropClosedCount uint64 + cacheWriteDropClosedLastLog int64 } // NewBillingCacheService 创建计费缓存服务 @@ -112,16 +119,25 @@ func (s *BillingCacheService) startCacheWriteWorkers() { } } -func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) { +// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 +func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { if s.cacheWriteChan == nil { - return + return false } defer func() { - _ = recover() + if recovered := recover(); recovered != nil { + // 队列已关闭时可能触发 panic,记录后静默失败。 + s.logCacheWriteDrop(task, "closed") + enqueued = false + } }() select { case s.cacheWriteChan <- task: + return true default: + // 队列满时不阻塞主流程,交由调用方决定是否同步回退。 + s.logCacheWriteDrop(task, "full") + return false } } @@ -151,6 +167,62 @@ func (s *BillingCacheService) cacheWriteWorker() { } } +// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。 +func cacheWriteKindName(kind cacheWriteKind) string { + switch kind { + case cacheWriteSetBalance: + return "set_balance" + case cacheWriteSetSubscription: + return "set_subscription" + case cacheWriteUpdateSubscriptionUsage: + return "update_subscription_usage" + case cacheWriteDeductBalance: + return "deduct_balance" + default: + return "unknown" + } +} + +// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。 +func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) { + var ( + countPtr *uint64 + lastPtr *int64 + ) + switch reason { + case "full": + countPtr = &s.cacheWriteDropFullCount + lastPtr = &s.cacheWriteDropFullLastLog + case "closed": + countPtr = &s.cacheWriteDropClosedCount + lastPtr = &s.cacheWriteDropClosedLastLog + default: + return + } + + atomic.AddUint64(countPtr, 1) + now := time.Now().UnixNano() + last := atomic.LoadInt64(lastPtr) + if now-last < int64(cacheWriteDropLogInterval) { + return + } + if !atomic.CompareAndSwapInt64(lastPtr, last, now) { + return + } + dropped := atomic.SwapUint64(countPtr, 0) + if dropped == 0 { + return + } + log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", + reason, + dropped, + cacheWriteDropLogInterval, + cacheWriteKindName(task.kind), + task.userID, + task.groupID, + ) +} + // ============================================ // 余额缓存方法 // ============================================ @@ -175,7 +247,7 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) } // 异步建立缓存 - s.enqueueCacheWrite(cacheWriteTask{ + _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetBalance, userID: userID, balance: balance, @@ -213,11 +285,22 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int // QueueDeductBalance 异步扣减余额缓存 func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { - s.enqueueCacheWrite(cacheWriteTask{ + if s.cache == nil { + return + } + // 队列满时同步回退,避免关键扣减被静默丢弃。 + if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteDeductBalance, userID: userID, amount: amount, - }) + }) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { + log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err) + } } // InvalidateUserBalance 失效用户余额缓存 @@ -255,7 +338,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, } // 异步建立缓存 - s.enqueueCacheWrite(cacheWriteTask{ + _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetSubscription, userID: userID, groupID: groupID, @@ -324,12 +407,23 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI // QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { - s.enqueueCacheWrite(cacheWriteTask{ + if s.cache == nil { + return + } + // 队列满时同步回退,确保订阅用量及时更新。 + if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteUpdateSubscriptionUsage, userID: userID, groupID: groupID, amount: costUSD, - }) + }) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { + log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) + } } // InvalidateSubscription 失效指定订阅缓存 diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 6d358c36..fbec1371 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -24,7 +24,7 @@ type ParsedRequest struct { MetadataUserID string // metadata.user_id(用于会话亲和) System any // system 字段内容 Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) } // ParseGatewayRequest 解析网关请求体并返回结构化结果 @@ -58,7 +58,9 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.MetadataUserID = userID } } - if system, ok := req["system"]; ok && system != nil { + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { parsed.HasSystem = true parsed.System = system } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index c921e0f6..5d411e2c 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -22,7 +22,9 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) parsed, err := ParseGatewayRequest(body) require.NoError(t, err) - require.False(t, parsed.HasSystem) + // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 + require.True(t, parsed.HasSystem) + require.Nil(t, parsed.System) } func TestParseGatewayRequest_InvalidModelType(t *testing.T) {