fix(服务): 修复system判定、统计时区与缓存日志

- system 字段存在即视为显式提供,避免 null 触发默认注入
- 日统计分组显式使用应用时区,缺失时从 TZ 回退到 UTC
- 缓存写入队列丢弃日志节流汇总,关键任务同步回退

测试: go test ./internal/service -run TestBillingCacheServiceQueueHighLoad
This commit is contained in:
yangjianbo
2025-12-31 10:17:38 +08:00
parent 7efa8b54c4
commit 3d7f8e4b3a
4 changed files with 132 additions and 18 deletions

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"strings" "strings"
"time" "time"
@@ -536,9 +537,11 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 // GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 // 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
tzName := resolveUsageStatsTimezone()
query := ` query := `
SELECT SELECT
TO_CHAR(created_at, 'YYYY-MM-DD') as date, -- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens, COALESCE(SUM(output_tokens), 0) as total_output_tokens,
@@ -552,7 +555,7 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
ORDER BY 1 ORDER BY 1
` `
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -607,6 +610,19 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID
return result, nil return result, nil
} }
// resolveUsageStatsTimezone returns the time zone name used for SQL-side
// grouping of usage statistics.
//
// Resolution order:
//  1. the name reported by timezone.Name(), unless it is empty or "Local";
//  2. the TZ environment variable with surrounding whitespace trimmed, if
//     the result is non-empty;
//  3. "UTC".
func resolveUsageStatsTimezone() string {
	switch appTZ := timezone.Name(); appTZ {
	case "", "Local":
		// No explicit zone name available from the application; try the
		// environment next.
	default:
		return appTZ
	}

	envTZ := strings.TrimSpace(os.Getenv("TZ"))
	if envTZ == "" {
		return "UTC"
	}
	return envTZ
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
@@ -49,12 +50,13 @@ const (
// 新实现使用固定大小的工作池: // 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine避免频繁创建销毁 // 1. 预创建 10 个 worker goroutine避免频繁创建销毁
// 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值 // 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受) // 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池 // 4. 统一超时控制,避免慢操作阻塞工作池
const ( const (
cacheWriteWorkerCount = 10 // 工作协程数量 cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
) )
// cacheWriteTask 缓存写入任务 // cacheWriteTask 缓存写入任务
@@ -78,6 +80,11 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once cacheWriteStopOnce sync.Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
@@ -112,16 +119,25 @@ func (s *BillingCacheService) startCacheWriteWorkers() {
} }
} }
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) { // enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil { if s.cacheWriteChan == nil {
return return false
} }
defer func() { defer func() {
_ = recover() if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic,记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}() }()
select { select {
case s.cacheWriteChan <- task: case s.cacheWriteChan <- task:
return true
default: default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
} }
} }
@@ -151,6 +167,62 @@ func (s *BillingCacheService) cacheWriteWorker() {
} }
} }
// cacheWriteKindName maps a cache write task kind to the short label used in
// log messages, which makes it easier to see what kind of task was dropped.
// Kinds without a dedicated label are reported as "unknown".
func cacheWriteKindName(kind cacheWriteKind) string {
	label := "unknown"
	switch kind {
	case cacheWriteSetBalance:
		label = "set_balance"
	case cacheWriteSetSubscription:
		label = "set_subscription"
	case cacheWriteUpdateSubscriptionUsage:
		label = "update_subscription_usage"
	case cacheWriteDeductBalance:
		label = "deduct_balance"
	}
	return label
}
// logCacheWriteDrop records a dropped cache write in a throttled way and
// reports the number of drops accumulated since the previous log line.
//
// reason selects which counter/timestamp pair is used: "full" or "closed".
// Any other reason is ignored (nothing is counted or logged). task is only
// used to describe the most recently dropped task in the log message.
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
// Each drop reason has its own counter and last-log timestamp.
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
// Count every drop, including the ones that will not be logged individually.
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
// Throttle: skip logging if the previous log line for this reason is more
// recent than cacheWriteDropLogInterval.
if now-last < int64(cacheWriteDropLogInterval) {
return
}
// Only the caller that wins the compare-and-swap emits the log line; callers
// that lose return here and their drops remain in the counter.
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
// Read the accumulated count and reset it to zero in one atomic step.
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
// NOTE(review): the count covers every drop since the previous log line,
// which can span longer than cacheWriteDropLogInterval even though the
// message says "in last <interval>".
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
}
// ============================================ // ============================================
// 余额缓存方法 // 余额缓存方法
// ============================================ // ============================================
@@ -175,7 +247,7 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
} }
// 异步建立缓存 // 异步建立缓存
s.enqueueCacheWrite(cacheWriteTask{ _ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance, kind: cacheWriteSetBalance,
userID: userID, userID: userID,
balance: balance, balance: balance,
@@ -213,11 +285,22 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
// QueueDeductBalance 异步扣减余额缓存 // QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
s.enqueueCacheWrite(cacheWriteTask{ if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance, kind: cacheWriteDeductBalance,
userID: userID, userID: userID,
amount: amount, amount: amount,
}) }) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
} }
// InvalidateUserBalance 失效用户余额缓存 // InvalidateUserBalance 失效用户余额缓存
@@ -255,7 +338,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
} }
// 异步建立缓存 // 异步建立缓存
s.enqueueCacheWrite(cacheWriteTask{ _ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetSubscription, kind: cacheWriteSetSubscription,
userID: userID, userID: userID,
groupID: groupID, groupID: groupID,
@@ -324,12 +407,23 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 // QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
s.enqueueCacheWrite(cacheWriteTask{ if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage, kind: cacheWriteUpdateSubscriptionUsage,
userID: userID, userID: userID,
groupID: groupID, groupID: groupID,
amount: costUSD, amount: costUSD,
}) }) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
} }
// InvalidateSubscription 失效指定订阅缓存 // InvalidateSubscription 失效指定订阅缓存

View File

@@ -24,7 +24,7 @@ type ParsedRequest struct {
MetadataUserID string // metadata.user_id用于会话亲和 MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容 System any // system 字段内容
Messages []any // messages 数组 Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段 HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
} }
// ParseGatewayRequest 解析网关请求体并返回结构化结果 // ParseGatewayRequest 解析网关请求体并返回结构化结果
@@ -58,7 +58,9 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID parsed.MetadataUserID = userID
} }
} }
if system, ok := req["system"]; ok && system != nil { // system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true parsed.HasSystem = true
parsed.System = system parsed.System = system
} }

View File

@@ -22,7 +22,9 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`) body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body) parsed, err := ParseGatewayRequest(body)
require.NoError(t, err) require.NoError(t, err)
require.False(t, parsed.HasSystem) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
require.Nil(t, parsed.System)
} }
func TestParseGatewayRequest_InvalidModelType(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) {