fix(服务): 修复system判定、统计时区与缓存日志

- system 字段存在即视为显式提供,避免 null 触发默认注入
- 日统计分组显式使用应用时区,缺失时从 TZ 回退到 UTC
- 缓存写入队列丢弃日志节流汇总,关键任务同步回退

测试: go test ./internal/service -run TestBillingCacheServiceQueueHighLoad
This commit is contained in:
yangjianbo
2025-12-31 10:17:38 +08:00
parent 7efa8b54c4
commit 3d7f8e4b3a
4 changed files with 132 additions and 18 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"time"
@@ -536,9 +537,11 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
tzName := resolveUsageStatsTimezone()
query := `
SELECT
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
@@ -552,7 +555,7 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
ORDER BY 1
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
if err != nil {
return nil, err
}
@@ -607,6 +610,19 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
return result, nil
}
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
func resolveUsageStatsTimezone() string {
tzName := timezone.Name()
if tzName != "" && tzName != "Local" {
return tzName
}
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
return envTZ
}
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -49,12 +50,13 @@ const (
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine避免频繁创建销毁
// 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受)
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
)
// cacheWriteTask 缓存写入任务
@@ -78,6 +80,11 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
}
// NewBillingCacheService 创建计费缓存服务
@@ -112,16 +119,25 @@ func (s *BillingCacheService) startCacheWriteWorkers() {
}
}
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) {
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
return
return false
}
defer func() {
_ = recover()
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
select {
case s.cacheWriteChan <- task:
return true
default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
}
}
@@ -151,6 +167,62 @@ func (s *BillingCacheService) cacheWriteWorker() {
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func cacheWriteKindName(kind cacheWriteKind) string {
switch kind {
case cacheWriteSetBalance:
return "set_balance"
case cacheWriteSetSubscription:
return "set_subscription"
case cacheWriteUpdateSubscriptionUsage:
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
default:
return "unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
if now-last < int64(cacheWriteDropLogInterval) {
return
}
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
}
// ============================================
// 余额缓存方法
// ============================================
@@ -175,7 +247,7 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
}
// 异步建立缓存
s.enqueueCacheWrite(cacheWriteTask{
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
@@ -213,11 +285,22 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
s.enqueueCacheWrite(cacheWriteTask{
if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
})
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
// InvalidateUserBalance 失效用户余额缓存
@@ -255,7 +338,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
}
// 异步建立缓存
s.enqueueCacheWrite(cacheWriteTask{
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetSubscription,
userID: userID,
groupID: groupID,
@@ -324,12 +407,23 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
s.enqueueCacheWrite(cacheWriteTask{
if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
})
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
// InvalidateSubscription 失效指定订阅缓存

View File

@@ -24,7 +24,7 @@ type ParsedRequest struct {
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
@@ -58,7 +58,9 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID
}
}
if system, ok := req["system"]; ok && system != nil {
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}

View File

@@ -22,7 +22,9 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.False(t, parsed.HasSystem)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
require.Nil(t, parsed.System)
}
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {