fix(服务): 修复system判定、统计时区与缓存日志
- system 字段存在即视为显式提供,避免 null 触发默认注入
- 日统计分组显式使用应用时区,缺失时从 TZ 回退到 UTC
- 缓存写入队列丢弃日志节流汇总,关键任务同步回退

测试: go test ./internal/service -run TestBillingCacheServiceQueueHighLoad
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -536,9 +537,11 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
|
||||
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
|
||||
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
|
||||
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
|
||||
tzName := resolveUsageStatsTimezone()
|
||||
query := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
|
||||
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
@@ -552,7 +555,7 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
|
||||
ORDER BY 1
|
||||
`
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -607,6 +610,19 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
|
||||
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
|
||||
func resolveUsageStatsTimezone() string {
|
||||
tzName := timezone.Name()
|
||||
if tzName != "" && tzName != "Local" {
|
||||
return tzName
|
||||
}
|
||||
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
|
||||
return envTZ
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -49,12 +50,13 @@ const (
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受)
|
||||
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
@@ -78,6 +80,11 @@ type BillingCacheService struct {
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
cacheWriteDropClosedCount uint64
|
||||
cacheWriteDropClosedLastLog int64
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
@@ -112,16 +119,25 @@ func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) {
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = recover()
|
||||
if recovered := recover(); recovered != nil {
|
||||
// 队列已关闭时可能触发 panic,记录后静默失败。
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
enqueued = false
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
default:
|
||||
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
|
||||
s.logCacheWriteDrop(task, "full")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,6 +167,62 @@ func (s *BillingCacheService) cacheWriteWorker() {
|
||||
}
|
||||
}
|
||||
|
||||
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
|
||||
func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
switch kind {
|
||||
case cacheWriteSetBalance:
|
||||
return "set_balance"
|
||||
case cacheWriteSetSubscription:
|
||||
return "set_subscription"
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
|
||||
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
|
||||
var (
|
||||
countPtr *uint64
|
||||
lastPtr *int64
|
||||
)
|
||||
switch reason {
|
||||
case "full":
|
||||
countPtr = &s.cacheWriteDropFullCount
|
||||
lastPtr = &s.cacheWriteDropFullLastLog
|
||||
case "closed":
|
||||
countPtr = &s.cacheWriteDropClosedCount
|
||||
lastPtr = &s.cacheWriteDropClosedLastLog
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(countPtr, 1)
|
||||
now := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(lastPtr)
|
||||
if now-last < int64(cacheWriteDropLogInterval) {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
|
||||
return
|
||||
}
|
||||
dropped := atomic.SwapUint64(countPtr, 0)
|
||||
if dropped == 0 {
|
||||
return
|
||||
}
|
||||
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
reason,
|
||||
dropped,
|
||||
cacheWriteDropLogInterval,
|
||||
cacheWriteKindName(task.kind),
|
||||
task.userID,
|
||||
task.groupID,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 余额缓存方法
|
||||
// ============================================
|
||||
@@ -175,7 +247,7 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
@@ -213,11 +285,22 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,避免关键扣减被静默丢弃。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
@@ -255,7 +338,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
@@ -324,12 +407,23 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,确保订阅用量及时更新。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
})
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||||
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
|
||||
@@ -24,7 +24,7 @@ type ParsedRequest struct {
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
@@ -58,7 +58,9 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
if system, ok := req["system"]; ok && system != nil {
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
|
||||
@@ -22,7 +22,9 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.False(t, parsed.HasSystem)
|
||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.Nil(t, parsed.System)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user