perf(gateway): 优化热点路径并补齐高覆盖测试
This commit is contained in:
@@ -64,4 +64,3 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
|
||||
require.NotNil(t, created.Extra)
|
||||
require.Equal(t, true, created.Extra["anthropic_passthrough"])
|
||||
}
|
||||
|
||||
|
||||
@@ -243,6 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
// 解析请求体为 map
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证是否为 Claude Code 客户端
|
||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
isClaudeCode := false
|
||||
if !strings.Contains(c.Request.URL.Path, "messages") {
|
||||
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
|
||||
isClaudeCode = true
|
||||
} else {
|
||||
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
}
|
||||
|
||||
// 更新 request context
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||
@@ -223,21 +238,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
|
||||
252
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
252
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type helperConcurrencyCacheStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
accountSeq []bool
|
||||
userSeq []bool
|
||||
|
||||
accountAcquireCalls int
|
||||
userAcquireCalls int
|
||||
accountReleaseCalls int
|
||||
userReleaseCalls int
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountAcquireCalls++
|
||||
if len(s.accountSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.accountSeq[0]
|
||||
s.accountSeq = s.accountSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userAcquireCalls++
|
||||
if len(s.userSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.userSeq[0]
|
||||
s.userSeq = s.userSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
out := make(map[int64]*service.UserLoadInfo, len(users))
|
||||
for _, user := range users {
|
||||
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(method, path, nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
}`)
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
|
||||
SetClaudeCodeClientContext(c, nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
// 缺少严格校验所需 header + body 字段
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, true},
|
||||
userSeq: []bool{false, true},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
|
||||
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
require.False(t, streamStarted)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
|
||||
})
|
||||
|
||||
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, false, false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
|
||||
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
})
|
||||
|
||||
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
||||
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.True(t, streamStarted)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||
errCache := &helperConcurrencyCacheStubWithError{
|
||||
err: errors.New("redis unavailable"),
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(errCache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "redis unavailable")
|
||||
}
|
||||
|
||||
type helperConcurrencyCacheStubWithError struct {
|
||||
helperConcurrencyCacheStub
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, s.err
|
||||
}
|
||||
@@ -263,6 +263,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
||||
|
||||
@@ -41,9 +41,8 @@ const (
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
requestBody []byte
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -58,6 +57,7 @@ var (
|
||||
opsErrorLogEnqueued atomic.Int64
|
||||
opsErrorLogDropped atomic.Int64
|
||||
opsErrorLogProcessed atomic.Int64
|
||||
opsErrorLogSanitized atomic.Int64
|
||||
|
||||
opsErrorLogLastDropLogAt atomic.Int64
|
||||
|
||||
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
|
||||
}
|
||||
|
||||
select {
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
|
||||
opsErrorLogQueueLen.Add(1)
|
||||
opsErrorLogEnqueued.Add(1)
|
||||
default:
|
||||
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
|
||||
return opsErrorLogProcessed.Load()
|
||||
}
|
||||
|
||||
func OpsErrorLogSanitizedTotal() int64 {
|
||||
return opsErrorLogSanitized.Load()
|
||||
}
|
||||
|
||||
func maybeLogOpsErrorLogDrop() {
|
||||
now := time.Now().Unix()
|
||||
|
||||
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
|
||||
queueCap := OpsErrorLogQueueCapacity()
|
||||
|
||||
log.Printf(
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
|
||||
queued,
|
||||
queueCap,
|
||||
opsErrorLogEnqueued.Load(),
|
||||
opsErrorLogDropped.Load(),
|
||||
opsErrorLogProcessed.Load(),
|
||||
opsErrorLogSanitized.Load(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -267,6 +272,22 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
}
|
||||
}
|
||||
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
v, ok := c.Get(opsRequestBodyKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
raw, ok := v.([]byte)
|
||||
if !ok || len(raw) == 0 {
|
||||
return
|
||||
}
|
||||
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
|
||||
opsErrorLogSanitized.Add(1)
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||
if c == nil || accountID <= 0 {
|
||||
return
|
||||
@@ -544,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
@@ -560,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -724,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||
// Do NOT store Authorization/Cookie/etc.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
175
backend/internal/handler/ops_error_logger_test.go
Normal file
175
backend/internal/handler/ops_error_logger_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resetOpsErrorLoggerStateForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
ch := opsErrorLogQueue
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
opsErrorLogWorkersWg.Wait()
|
||||
|
||||
opsErrorLogOnce = sync.Once{}
|
||||
opsErrorLogStopOnce = sync.Once{}
|
||||
opsErrorLogWorkersWg = sync.WaitGroup{}
|
||||
opsErrorLogMu = sync.RWMutex{}
|
||||
opsErrorLogStopping = false
|
||||
|
||||
opsErrorLogQueueLen.Store(0)
|
||||
opsErrorLogEnqueued.Store(0)
|
||||
opsErrorLogDropped.Store(0)
|
||||
opsErrorLogProcessed.Store(0)
|
||||
opsErrorLogSanitized.Store(0)
|
||||
opsErrorLogLastDropLogAt.Store(0)
|
||||
|
||||
opsErrorLogShutdownCh = make(chan struct{})
|
||||
opsErrorLogShutdownOnce = sync.Once{}
|
||||
opsErrorLogDrained.Store(false)
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.NotNil(t, entry.RequestBodyJSON)
|
||||
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
|
||||
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte("not-json")
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = make(chan opsErrorLogJob, 1)
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
|
||||
require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogQueueLength())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(nil, entry)
|
||||
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 无请求体 key
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
|
||||
// 错误类型
|
||||
c.Set(opsRequestBodyKey, "not-bytes")
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
// 空 bytes
|
||||
c.Set(opsRequestBodyKey, []byte{})
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
// nil 入参分支
|
||||
enqueueOpsErrorLog(nil, entry)
|
||||
enqueueOpsErrorLog(ops, nil)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// shutdown 分支
|
||||
close(opsErrorLogShutdownCh)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// stopping 分支
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// queue nil 分支(防止启动 worker 干扰)
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
}
|
||||
Reference in New Issue
Block a user