package handler

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

// helperConcurrencyCacheStub is an in-memory concurrency cache used by the
// handler helper tests.
//
// AcquireAccountSlot and AcquireUserSlot return successive values popped from
// accountSeq and userSeq respectively, and return false once the sequence is
// exhausted. Every acquire/release call is counted. All mutable fields are
// guarded by mu, so tests must read the counters through accountCalls and
// userCalls instead of touching the fields directly.
type helperConcurrencyCacheStub struct {
	mu sync.Mutex

	accountSeq []bool
	userSeq    []bool

	accountAcquireCalls int
	userAcquireCalls    int
	accountReleaseCalls int
	userReleaseCalls    int
}

// popAcquireResult removes and returns the first element of *seq, or false
// when the sequence is empty. The caller must hold the stub's mutex.
func popAcquireResult(seq *[]bool) bool {
	if len(*seq) == 0 {
		return false
	}
	v := (*seq)[0]
	*seq = (*seq)[1:]
	return v
}

// accountCalls returns the number of AcquireAccountSlot and ReleaseAccountSlot
// calls recorded so far, read under the stub's mutex.
func (s *helperConcurrencyCacheStub) accountCalls() (acquires, releases int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.accountAcquireCalls, s.accountReleaseCalls
}

// userCalls returns the number of AcquireUserSlot and ReleaseUserSlot calls
// recorded so far, read under the stub's mutex.
func (s *helperConcurrencyCacheStub) userCalls() (acquires, releases int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.userAcquireCalls, s.userReleaseCalls
}

// AcquireAccountSlot counts the call and returns the next scripted result from
// accountSeq (false once exhausted). It never returns an error.
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.accountAcquireCalls++
	return popAcquireResult(&s.accountSeq), nil
}

// ReleaseAccountSlot counts the call and always succeeds.
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.accountReleaseCalls++
	return nil
}

func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}

func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}

func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
	return nil
}

func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}

// AcquireUserSlot counts the call and returns the next scripted result from
// userSeq (false once exhausted). It never returns an error.
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.userAcquireCalls++
	return popAcquireResult(&s.userSeq), nil
}

// ReleaseUserSlot counts the call and always succeeds.
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.userReleaseCalls++
	return nil
}

func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}

func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	return true, nil
}

func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
	return nil
}

// GetAccountsLoadBatch returns an AccountLoadInfo carrying only the AccountID
// for each requested account.
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	out := make(map[int64]*service.AccountLoadInfo, len(accounts))
	for _, acc := range accounts {
		out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
	}
	return out, nil
}

// GetUsersLoadBatch returns a UserLoadInfo carrying only the UserID for each
// requested user.
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
	out := make(map[int64]*service.UserLoadInfo, len(users))
	for _, user := range users {
		out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
	}
	return out, nil
}

func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	return nil
}

// newHelperTestContext builds a gin test context whose request uses the given
// method and path with an empty body, plus the recorder capturing its output.
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(method, path, nil)
	return c, rec
}

// validClaudeCodeBodyJSON returns a /v1/messages request body carrying the
// model, system prompt and metadata.user_id fields used by the strict
// Claude Code client check.
func validClaudeCodeBodyJSON() []byte {
	return []byte(`{ "model":"claude-3-5-sonnet-20241022", "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} }`)
}

func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
	t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "curl/8.6.0")

		SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
		require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
	})

	t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")

		SetClaudeCodeClientContext(c, nil)
		require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
	})

	t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
		c.Request.Header.Set("X-App", "claude-code")
		c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
		c.Request.Header.Set("anthropic-version", "2023-06-01")

		SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
		require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
	})

	t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
		// Missing the headers and body fields required by strict validation.
		SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
		require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
	})
}

func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false, true},
		userSeq:    []bool{false, true},
	}
	concurrency := service.NewConcurrencyService(cache)
	helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)

	t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
		require.NoError(t, err)
		require.NotNil(t, release)
		require.False(t, streamStarted)
		release()

		acquires, releases := cache.accountCalls()
		require.GreaterOrEqual(t, acquires, 2)
		require.GreaterOrEqual(t, releases, 1)
	})

	t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
		require.NoError(t, err)
		require.NotNil(t, release)
		require.False(t, streamStarted)
		release()

		acquires, releases := cache.userCalls()
		require.GreaterOrEqual(t, acquires, 2)
		require.GreaterOrEqual(t, releases, 1)
	})
}

func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
	// Once accountSeq is exhausted the stub keeps returning false, so both
	// subtests can share this cache and neither ever acquires a slot.
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false, false, false},
	}
	concurrency := service.NewConcurrencyService(cache)

	t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
		helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
		c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
		require.Nil(t, release)

		var cErr *ConcurrencyError
		require.ErrorAs(t, err, &cErr)
		require.True(t, cErr.IsTimeout)
	})

	t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
		helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
		c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
		streamStarted := false
		release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
		require.Nil(t, release)

		var cErr *ConcurrencyError
		require.ErrorAs(t, err, &cErr)
		require.True(t, cErr.IsTimeout)
		require.True(t, streamStarted)
		require.Contains(t, rec.Body.String(), ":\n\n")
	})
}

func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
	errCache := &helperConcurrencyCacheStubWithError{
		err: errors.New("redis unavailable"),
	}
	concurrency := service.NewConcurrencyService(errCache)
	helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)

	c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
	streamStarted := false
	release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
	require.Nil(t, release)
	require.Error(t, err)
	require.Contains(t, err.Error(), "redis unavailable")
}

func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
	cache := &helperConcurrencyCacheStub{
		accountSeq: []bool{false},
	}
	concurrency := service.NewConcurrencyService(cache)
	helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)

	c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
	streamStarted := false
	release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
	require.Nil(t, release)

	var cErr *ConcurrencyError
	require.ErrorAs(t, err, &cErr)
	require.True(t, cErr.IsTimeout)

	acquires, _ := cache.accountCalls()
	require.GreaterOrEqual(t, acquires, 1)
}

// helperConcurrencyCacheStubWithError behaves like helperConcurrencyCacheStub
// except that AcquireAccountSlot always fails with err (and is not counted).
type helperConcurrencyCacheStubWithError struct {
	helperConcurrencyCacheStub
	err error
}

// AcquireAccountSlot always returns (false, s.err).
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	return false, s.err
}