From a89477ddf57f53a38a8af16fbd95306c3f0a1715 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 22 Feb 2026 13:31:30 +0800 Subject: [PATCH] =?UTF-8?q?perf(gateway):=20=E4=BC=98=E5=8C=96=E7=83=AD?= =?UTF-8?q?=E7=82=B9=E8=B7=AF=E5=BE=84=E5=B9=B6=E8=A1=A5=E9=BD=90=E9=AB=98?= =?UTF-8?q?=E8=A6=86=E7=9B=96=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 - backend/internal/config/config.go | 13 + backend/internal/config/config_test.go | 10 + .../admin/account_handler_passthrough_test.go | 1 - backend/internal/handler/gateway_handler.go | 4 + backend/internal/handler/gateway_helper.go | 42 +- .../handler/gateway_helper_hotpath_test.go | 252 ++++++ .../internal/handler/gemini_v1beta_handler.go | 4 + backend/internal/handler/ops_error_logger.go | 53 +- .../internal/handler/ops_error_logger_test.go | 175 ++++ backend/internal/pkg/ctxkey/ctxkey.go | 4 + backend/internal/repository/usage_log_repo.go | 53 ++ .../gateway_hotpath_optimization_test.go | 755 ++++++++++++++++++ backend/internal/service/gateway_service.go | 386 ++++++++- backend/internal/service/ops_service.go | 23 +- .../service/ops_service_prepare_queue_test.go | 60 ++ 16 files changed, 1760 insertions(+), 76 deletions(-) create mode 100644 backend/internal/handler/gateway_helper_hotpath_test.go create mode 100644 backend/internal/handler/ops_error_logger_test.go create mode 100644 backend/internal/service/gateway_hotpath_optimization_test.go create mode 100644 backend/internal/service/ops_service_prepare_queue_test.go diff --git a/.gitignore b/.gitignore index 2062600f..f1f6fb6e 100644 --- a/.gitignore +++ b/.gitignore @@ -121,7 +121,6 @@ AGENTS.md scripts .code-review-state openspec/ -docs/ code-reviews/ AGENTS.md backend/cmd/server/server diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 777993cd..1ddd3d14 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -423,6 +423,11 @@ type GatewayConfig struct { // UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker) UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` + + // UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒) + UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` + // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) + ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` } // GatewayUsageRecordConfig 使用量记录异步队列配置 @@ -1175,6 +1180,8 @@ func setDefaults() { viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) + viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) + viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) @@ -1751,6 +1758,12 @@ func (c *Config) Validate() error { return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative") } } + if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive") + } + if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 { + return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 2e79e5ed..1bba2f9d 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1010,6 +1010,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 }, wantErr: "gateway.usage_record.auto_scale_check_interval_seconds", }, + { + name: "gateway user group rate cache ttl", + mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 }, + wantErr: "gateway.user_group_rate_cache_ttl_seconds", + }, + { + name: "gateway models list cache ttl range", + mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 }, + wantErr: "gateway.models_list_cache_ttl_seconds", + }, { name: "gateway scheduling sticky waiting", mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go index b6720451..d09cccd6 100644 --- a/backend/internal/handler/admin/account_handler_passthrough_test.go +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -64,4 +64,3 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi require.NotNil(t, created.Extra) require.Equal(t, true, created.Extra["anthropic_passthrough"]) } - diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index bbe73689..9bf0fcd2 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -243,6 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + c.Request = c.Request.WithContext(ctx) + } } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index c4edf53b..6127dda7 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand/v2" "net/http" + "strings" "sync" "time" @@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator() // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context func SetClaudeCodeClientContext(c *gin.Context, body []byte) { - // 解析请求体为 map - var bodyMap map[string]any - if len(body) > 0 { - _ = json.Unmarshal(body, &bodyMap) + if c == nil || c.Request == nil { + return + } + // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 + if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) { + ctx := service.SetClaudeCodeClient(c.Request.Context(), false) + c.Request = c.Request.WithContext(ctx) + return } - // 验证是否为 Claude Code 客户端 - isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + isClaudeCode := false + if !strings.Contains(c.Request.URL.Path, "messages") { + // 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。 + isClaudeCode = true + } else { + // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 + var bodyMap map[string]any + if len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) + } // 更新 request context ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) @@ -223,21 +238,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() - // Try immediate acquire first (avoid unnecessary wait) - var result *service.AcquireResult - var err error - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) - } - if err != nil { - return nil, err - } - if result.Acquired { - return result.ReleaseFunc, nil - } - // Determine if ping is needed (streaming + ping format defined) needPing := isStream && h.pingFormat != "" diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go new file mode 100644 index 00000000..2149c130 --- /dev/null +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -0,0 +1,252 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type helperConcurrencyCacheStub struct { + mu sync.Mutex + + accountSeq []bool + userSeq []bool + + accountAcquireCalls int + userAcquireCalls int + accountReleaseCalls int + userReleaseCalls int +} + +func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.accountAcquireCalls++ + if len(s.accountSeq) == 0 { + return false, nil + } + v := s.accountSeq[0] + s.accountSeq = s.accountSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.accountReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.userAcquireCalls++ + if len(s.userSeq) == 0 { + return false, nil + } + v := s.userSeq[0] + s.userSeq = s.userSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.userReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + out := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + out := make(map[int64]*service.UserLoadInfo, len(users)) + for _, user := range users { + out[user.ID] = &service.UserLoadInfo{UserID: user.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(method, path, nil) + return c, rec +} + +func validClaudeCodeBodyJSON() []byte { + return []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + }`) +} + +func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { + t.Run("non_cli_user_agent_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "curl/8.6.0") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_non_messages_path_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodGet, "/v1/models") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + + SetClaudeCodeClientContext(c, nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + // 缺少严格校验所需 header + body 字段 + SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`)) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, true}, + userSeq: []bool{false, true}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + + t.Run("account_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted) + require.NoError(t, err) + require.NotNil(t, release) + require.False(t, streamStarted) + release() + require.GreaterOrEqual(t, cache.accountAcquireCalls, 2) + require.GreaterOrEqual(t, cache.accountReleaseCalls, 1) + }) + + t.Run("user_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted) + require.NoError(t, err) + require.NotNil(t, release) + release() + require.GreaterOrEqual(t, cache.userAcquireCalls, 2) + require.GreaterOrEqual(t, cache.userReleaseCalls, 1) + }) +} + +func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, false, false}, + } + concurrency := service.NewConcurrencyService(cache) + + t.Run("timeout_returns_concurrency_error", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + }) + + t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) + c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.True(t, streamStarted) + require.Contains(t, rec.Body.String(), ":\n\n") + }) +} + +func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { + errCache := &helperConcurrencyCacheStubWithError{ + err: errors.New("redis unavailable"), + } + concurrency := service.NewConcurrencyService(errCache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + require.Error(t, err) + require.Contains(t, err.Error(), "redis unavailable") +} + +type helperConcurrencyCacheStubWithError struct { + helperConcurrencyCacheStub + err error +} + +func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, s.err +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 86c2e4a4..c96484a6 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -263,6 +263,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + c.Request = c.Request.WithContext(ctx) + } } // === Gemini 内容摘要会话 Fallback 逻辑 === diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 2328a920..ab9a2167 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -41,9 +41,8 @@ const ( ) type opsErrorLogJob struct { - ops *service.OpsService - entry *service.OpsInsertErrorLogInput - requestBody []byte + ops *service.OpsService + entry *service.OpsInsertErrorLogInput } var ( @@ -58,6 +57,7 @@ var ( opsErrorLogEnqueued atomic.Int64 opsErrorLogDropped atomic.Int64 opsErrorLogProcessed atomic.Int64 + opsErrorLogSanitized atomic.Int64 opsErrorLogLastDropLogAt atomic.Int64 @@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() { } }() ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, job.requestBody) + _ = job.ops.RecordError(ctx, job.entry, nil) cancel() opsErrorLogProcessed.Add(1) }() @@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() { } } -func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { +func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return } @@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo } select { - case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: + case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}: opsErrorLogQueueLen.Add(1) opsErrorLogEnqueued.Add(1) default: @@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 { return opsErrorLogProcessed.Load() } +func OpsErrorLogSanitizedTotal() int64 { + return opsErrorLogSanitized.Load() +} + func maybeLogOpsErrorLogDrop() { now := time.Now().Unix() @@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() { queueCap := OpsErrorLogQueueCapacity() log.Printf( - "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", + "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)", queued, queueCap, opsErrorLogEnqueued.Load(), opsErrorLogDropped.Load(), opsErrorLogProcessed.Load(), + opsErrorLogSanitized.Load(), ) } @@ -267,6 +272,22 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody } } +func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + v, ok := c.Get(opsRequestBodyKey) + if !ok { + return + } + raw, ok := v.([]byte) + if !ok || len(raw) == 0 { + return + } + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw) + opsErrorLogSanitized.Add(1) +} + func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) { if c == nil || accountID <= 0 { return @@ -544,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Store request headers/body only when an upstream error occurred to keep overhead minimal. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) // Skip logging if a passthrough rule with skip_monitoring=true matched. if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { @@ -560,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } } - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) return } @@ -724,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Persist only a minimal, whitelisted set of request headers to improve retry fidelity. // Do NOT store Authorization/Cookie/etc. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) } } diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go new file mode 100644 index 00000000..a11fa1f2 --- /dev/null +++ b/backend/internal/handler/ops_error_logger_test.go @@ -0,0 +1,175 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func resetOpsErrorLoggerStateForTest(t *testing.T) { + t.Helper() + + opsErrorLogMu.Lock() + ch := opsErrorLogQueue + opsErrorLogQueue = nil + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + + if ch != nil { + close(ch) + } + opsErrorLogWorkersWg.Wait() + + opsErrorLogOnce = sync.Once{} + opsErrorLogStopOnce = sync.Once{} + opsErrorLogWorkersWg = sync.WaitGroup{} + opsErrorLogMu = sync.RWMutex{} + opsErrorLogStopping = false + + opsErrorLogQueueLen.Store(0) + opsErrorLogEnqueued.Store(0) + opsErrorLogDropped.Store(0) + opsErrorLogProcessed.Store(0) + opsErrorLogSanitized.Store(0) + opsErrorLogLastDropLogAt.Store(0) + + opsErrorLogShutdownCh = make(chan struct{}) + opsErrorLogShutdownOnce = sync.Once{} + opsErrorLogDrained.Store(false) +} + +func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`) + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.NotNil(t, entry.RequestBodyJSON) + require.NotContains(t, *entry.RequestBodyJSON, "secret-token") + require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]") + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte("not-json") + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.Nil(t, entry.RequestBodyJSON) + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + // 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。 + opsErrorLogOnce.Do(func() {}) + + opsErrorLogMu.Lock() + opsErrorLogQueue = make(chan opsErrorLogJob, 1) + opsErrorLogMu.Unlock() + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + enqueueOpsErrorLog(ops, entry) + enqueueOpsErrorLog(ops, entry) + + require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal()) + require.Equal(t, int64(1), OpsErrorLogDroppedTotal()) + require.Equal(t, int64(1), OpsErrorLogQueueLength()) +} + +func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(nil, entry) + attachOpsRequestBodyToEntry(&gin.Context{}, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 无请求体 key + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + + // 错误类型 + c.Set(opsRequestBodyKey, "not-bytes") + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + // 空 bytes + c.Set(opsRequestBodyKey, []byte{}) + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + require.Equal(t, int64(0), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + // nil 入参分支 + enqueueOpsErrorLog(nil, entry) + enqueueOpsErrorLog(ops, nil) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // shutdown 分支 + close(opsErrorLogShutdownCh) + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // stopping 分支 + resetOpsErrorLoggerStateForTest(t) + opsErrorLogMu.Lock() + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // queue nil 分支(防止启动 worker 干扰) + resetOpsErrorLoggerStateForTest(t) + opsErrorLogOnce.Do(func() {}) + opsErrorLogMu.Lock() + opsErrorLogQueue = nil + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) +} diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 54add8a0..a320ee8c 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -44,4 +44,8 @@ const ( // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 SingleAccountRetry Key = "ctx_single_account_retry" + + // PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。 + // Service 层可复用该值,避免同请求链路重复读取 Redis。 + PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" ) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 0389a008..ce67ba4d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -915,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI return stats, nil } +// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。 +// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。 +func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + result := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COUNT(*) as requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, + COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(total_cost), 0) as standard_cost, + COALESCE(SUM(actual_cost), 0) as user_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + stats := &usagestats.AccountStats{} + if err := rows.Scan( + &accountID, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + &stats.StandardCost, + &stats.UserCost, + ); err != nil { + return nil, err + } + result[accountID] = stats + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &usagestats.AccountStats{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go new file mode 100644 index 00000000..81824cb3 --- /dev/null +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -0,0 +1,755 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateRepoHotpathStub struct { + UserGroupRateRepository + + rate *float64 + err error + wait <-chan struct{} + calls atomic.Int64 +} + +func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls.Add(1) + if s.wait != nil { + <-s.wait + } + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +type usageLogWindowBatchRepoStub struct { + UsageLogRepository + + batchResult map[int64]*usagestats.AccountStats + batchErr error + batchCalls atomic.Int64 + + singleResult map[int64]*usagestats.AccountStats + singleErr error + singleCalls atomic.Int64 +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + s.batchCalls.Add(1) + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + for _, id := range accountIDs { + if stats, ok := s.batchResult[id]; ok { + out[id] = stats + } + } + return out, nil +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + s.singleCalls.Add(1) + if s.singleErr != nil { + return nil, s.singleErr + } + if stats, ok := s.singleResult[accountID]; ok { + return stats, nil + } + return &usagestats.AccountStats{}, nil +} + +type sessionLimitCacheHotpathStub struct { + SessionLimitCache + + batchData map[int64]float64 + batchErr error + + setData map[int64]float64 + setErr error +} + +func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) { + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]float64, len(accountIDs)) + for _, id := range accountIDs { + if v, ok := s.batchData[id]; ok { + out[id] = v + } + } + return out, nil +} + +func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error { + if s.setErr != nil { + return s.setErr + } + if s.setData == nil { + s.setData = make(map[int64]float64) + } + s.setData[accountID] = cost + return nil +} + +type modelsListAccountRepoStub struct { + AccountRepository + + byGroup map[int64][]Account + all []Account + err error + + listByGroupCalls atomic.Int64 + listAllCalls atomic.Int64 +} + +type stickyGatewayCacheHotpathStub struct { + GatewayCache + + stickyID int64 + getCalls atomic.Int64 +} + +func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + s.getCalls.Add(1) + if s.stickyID > 0 { + return s.stickyID, nil + } + return 0, errors.New("not found") +} + +func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + return nil +} + +func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + s.listByGroupCalls.Add(1) + if s.err != nil { + return nil, s.err + } + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) { + s.listAllCalls.Add(1) + if s.err != nil { + return nil, s.err + } + out := make([]Account, len(s.all)) + copy(out, s.all) + return out, nil +} + +func resetGatewayHotpathStatsForTest() { + windowCostPrefetchCacheHitTotal.Store(0) + windowCostPrefetchCacheMissTotal.Store(0) + windowCostPrefetchBatchSQLTotal.Store(0) + windowCostPrefetchFallbackTotal.Store(0) + windowCostPrefetchErrorTotal.Store(0) + + userGroupRateCacheHitTotal.Store(0) + userGroupRateCacheMissTotal.Store(0) + userGroupRateCacheLoadTotal.Store(0) + userGroupRateCacheSFSharedTotal.Store(0) + userGroupRateCacheFallbackTotal.Store(0) + + modelsListCacheHitTotal.Store(0) + modelsListCacheMissTotal.Store(0) + modelsListCacheStoreTotal.Store(0) +} + +func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + unblock := make(chan struct{}) + repo := &userGroupRateRepoHotpathStub{ + rate: &rate, + wait: unblock, + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + const concurrent = 12 + results := make([]float64, concurrent) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(concurrent) + for i := 0; i < concurrent; i++ { + go func(idx int) { + defer wg.Done() + <-start + results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + }(i) + } + + close(start) + time.Sleep(20 * time.Millisecond) + close(unblock) + wg.Wait() + + for _, got := range results { + require.Equal(t, rate, got) + } + require.Equal(t, int64(1), repo.calls.Load()) + + // 再次读取应命中缓存,不再回源。 + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, int64(1), repo.calls.Load()) + + hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats() + require.GreaterOrEqual(t, hit, int64(1)) + require.Equal(t, int64(12), miss) + require.Equal(t, int64(1), load) + require.GreaterOrEqual(t, sfShared, int64(1)) + require.Equal(t, int64(0), fallback) +} + +func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("db down"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25) + require.Equal(t, 1.25, got) + require.Equal(t, int64(1), repo.calls.Load()) + + _, _, _, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), fallback) +} + +func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("should not be called"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + key := "101:202" + svc.userGroupRateCache.Set(key, 2.3, time.Minute) + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1) + require.Equal(t, 2.3, got) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), load) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), repo.calls.Load()) + + // 无 repo 时直接返回分组默认倍率 + svc2 := &GatewayService{ + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + svc2.userGroupRateCache.Set(key, 1.9, time.Minute) + require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4)) + svc2.userGroupRateCache.Delete(key) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) +} + +func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{"window_cost_limit": 100.0}, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + }, + } + repo := &usageLogWindowBatchRepoStub{ + batchResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 22.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + require.NotNil(t, outCtx) + + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + require.True(t, ok1) + require.Equal(t, 11.0, cost1) + + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok2) + require.Equal(t, 22.0, cost2) + + _, ok3 := windowCostFromPrefetchContext(outCtx, 3) + require.False(t, ok3) + + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, 22.0, cache.setData[2]) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + 2: 22.0, + }, + } + repo := &usageLogWindowBatchRepoStub{} + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok1) + require.True(t, ok2) + require.Equal(t, 11.0, cost1) + require.Equal(t, 22.0, cost2) + require.Equal(t, int64(0), repo.batchCalls.Load()) + require.Equal(t, int64(0), repo.singleCalls.Load()) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{} + repo := &usageLogWindowBatchRepoStub{ + batchErr: errors.New("batch failed"), + singleResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 33.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost, ok := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok) + require.Equal(t, 33.0, cost) + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, int64(1), repo.singleCalls.Load()) + + _, _, _, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), fallback) + require.Equal(t, int64(1), errCount) +} + +func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) { + resetGatewayHotpathStatsForTest() + + groupID := int64(9) + repo := &modelsListAccountRepoStub{ + byGroup: map[int64][]Account{ + groupID: { + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-haiku": "claude-3-5-haiku", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // TTL 内再次请求应命中缓存,不回源。 + models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, models1, models2) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // 更新仓储数据,但缓存未失效前应继续返回旧值。 + repo.byGroup[groupID] = []Account{ + { + ID: 3, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-7-sonnet": "claude-3-7-sonnet", + }, + }, + }, + } + models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic) + models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-7-sonnet"}, models4) + require.Equal(t, int64(2), repo.listByGroupCalls.Load()) + + hit, miss, store := GatewayModelsListCacheStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(2), miss) + require.Equal(t, int64(2), store) +} + +func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + errRepo := &modelsListAccountRepoStub{ + err: errors.New("db error"), + } + svcErr := &GatewayService{ + accountRepo: errRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, "")) + + okRepo := &modelsListAccountRepoStub{ + all: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + } + svcOK := &GatewayService{ + accountRepo: okRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + models := svcOK.GetAvailableModels(context.Background(), nil, "") + require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models) + require.Equal(t, int64(1), okRepo.listAllCalls.Load()) +} + +func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) { + t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 45, + }, + } + require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg)) + }) + + t.Run("resolve_models_list_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + ModelsListCacheTTLSeconds: 20, + }, + } + require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg)) + }) + + t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO())) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background())) + + ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) + require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx)) + + ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) + require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2)) + + ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3)) + }) + + t.Run("window_cost_from_prefetch_context", func(t *testing.T) { + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.TODO(), 0) + return ok + }()) + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.Background(), 1) + return ok + }()) + + ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{ + 9: 12.34, + }) + cost, ok := windowCostFromPrefetchContext(ctx, 9) + require.True(t, ok) + require.Equal(t, 12.34, cost) + }) +} + +func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) { + svc := &GatewayService{ + modelsListCache: gocache.New(time.Minute, time.Minute), + } + group9 := int64(9) + group10 := int64(10) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute) + + t.Run("invalidate_group_and_platform", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic) + _, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + require.False(t, found) + _, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.True(t, stillFound) + }) + + t.Run("invalidate_group_only", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, "") + _, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, foundA) + require.False(t, foundB) + _, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + require.True(t, foundOtherGroup) + }) + + t.Run("invalidate_platform_only", func(t *testing.T) { + // 重建数据后仅按 platform 失效 + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + + svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic) + _, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + _, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, found9Anthropic) + require.False(t, found10Anthropic) + require.True(t, found9Gemini) + }) +} + +func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { + now := time.Now().Add(-time.Minute) + account := Account{ + ID: 88, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 4, + Priority: 1, + LastUsedAt: &now, + } + + repo := stubOpenAIAccountRepo{accounts: []Account{account}} + concurrency := NewConcurrencyService(stubConcurrencyCache{}) + + cfg := &config.Config{ + RunMode: config.RunModeStandard, + Gateway: config.GatewayConfig{ + Scheduling: config.GatewaySchedulingConfig{ + LoadBatchEnabled: true, + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: time.Second, + FallbackWaitTimeout: time.Second, + FallbackMaxWaiting: 10, + }, + }, + } + + baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic) + + t.Run("without_prefetch_reads_cache_once", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_skips_cache_read", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(0), cache.getCalls.Load()) + }) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f16f685f..ae637ee3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -24,12 +24,15 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/cespare/xxhash/v2" "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" "github.com/gin-gonic/gin" ) @@ -44,6 +47,9 @@ const ( // separator between system blocks, we add "\n\n" at concatenation time. claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + + defaultUserGroupRateCacheTTL = 30 * time.Second + defaultModelsListCacheTTL = 15 * time.Second ) const ( @@ -62,6 +68,53 @@ type accountWithLoad struct { var ForceCacheBillingContextKey = forceCacheBillingKeyType{} +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} + +func cloneStringSlice(src []string) []string { + if len(src) == 0 { + return nil + } + dst := make([]string, len(src)) + copy(dst, src) + return dst +} + // IsForceCacheBilling 检查是否启用强制缓存计费 func IsForceCacheBilling(ctx context.Context) bool { v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) @@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 { return *groupID } +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 { + if ctx == nil { + return 0 + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + case int64: + if t > 0 { + return t + } + case int: + if t > 0 { + return int64(t) + } + } + return 0 +} + // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 或请求的模型处于限流状态时,返回 true。 @@ -421,6 +510,10 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration } // NewGatewayService creates a new GatewayService @@ -445,6 +538,9 @@ func NewGatewayService( sessionLimitCache SessionLimitCache, digestStore *DigestSessionStore, ) *GatewayService { + userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) + modelsListTTL := resolveModelsListCacheTTL(cfg) + return &GatewayService{ accountRepo: accountRepo, groupRepo: groupRepo, @@ -465,6 +561,9 @@ func NewGatewayService( deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, } } @@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { + if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID } @@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(accounts) == 0 { return nil, errors.New("no available accounts") } + ctx = s.withWindowCostPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(routingCandidates) > 0 { // 1.5. 在路由账号范围内检查粘性会话 - if sessionHash != "" && s.cache != nil { - stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + if sessionHash != "" && stickyAccountID > 0 { + if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount.IsSchedulable() && @@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 @@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + +type windowCostPrefetchContextKeyType struct{} + +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} + +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false + } + v, exists := m[accountID] + return v, exists +} + +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + return ctx + } + + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue + } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx + } + + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost + } + windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 + } + windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) + + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue + } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) + } + + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + "window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue + } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) + } + + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue + } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + } + + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 尝试从缓存获取窗口费用 var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability + } if s.sessionLimitCache != nil { if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { currentCost = cost @@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo return body } +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier + } + + key := fmt.Sprintf("%d:%d", userID, groupID) + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } + } + if s.userGroupRateRepo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) + + value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if s.userGroupRateCache != nil { + s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } + if err != nil { + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier + } + + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier +} + // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult @@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier - - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown @@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier - - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown @@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + var accounts []Account var err error @@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, // If no account has model_mapping, return nil (use default) if !hasAnyMapping { + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } return nil } @@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, for model := range modelSet { models = append(models, model) } + sort.Strings(models) - return models + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } } // reconcileCachedTokens 兼容 Kimi 等上游: diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index ed54bf6a..767d1704 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -20,6 +20,22 @@ const ( opsMaxStoredErrorBodyBytes = 20 * 1024 ) +// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。 +// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。 +func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) { + if len(raw) == 0 { + return nil, false, nil + } + sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes) + if sanitized != "" { + out := sanitized + requestBodyJSON = &out + } + n := bytesLen + requestBodyBytes = &n + return requestBodyJSON, truncated, requestBodyBytes +} + // OpsService provides ingestion and query APIs for the Ops monitoring module. type OpsService struct { opsRepo OpsRepository @@ -132,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn // Sanitize + trim request body (errors only). if len(rawRequestBody) > 0 { - sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes) - if sanitized != "" { - entry.RequestBodyJSON = &sanitized - } - entry.RequestBodyTruncated = truncated - entry.RequestBodyBytes = &bytesLen + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody) } // Sanitize + truncate error_body to avoid storing sensitive data. diff --git a/backend/internal/service/ops_service_prepare_queue_test.go b/backend/internal/service/ops_service_prepare_queue_test.go new file mode 100644 index 00000000..d6f32c2d --- /dev/null +++ b/backend/internal/service/ops_service_prepare_queue_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) { + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.Nil(t, requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) { + raw := []byte("{invalid-json") + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.NotNil(t, requestBodyBytes) + require.Equal(t, len(raw), *requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) { + raw := []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "api_key":"sk-test-123", + "headers":{"authorization":"Bearer secret-token"}, + "messages":[{"role":"user","content":"hello"}] + }`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.False(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + + var body map[string]any + require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body)) + require.Equal(t, "[REDACTED]", body["api_key"]) + headers, ok := body["headers"].(map[string]any) + require.True(t, ok) + require.Equal(t, "[REDACTED]", headers["authorization"]) +} + +func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) { + largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2) + raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.True(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes) + require.Contains(t, *requestBodyJSON, "request_body_truncated") +}