270 lines
9.2 KiB
Go
270 lines
9.2 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// helperConcurrencyCacheStub is an in-memory stand-in for the concurrency
// cache used by the ConcurrencyHelper tests. Slot acquisitions replay a
// scripted sequence of outcomes and every acquire/release call is counted.
// mu guards the sequences and the counters.
type helperConcurrencyCacheStub struct {
	mu sync.Mutex

	// Scripted acquire outcomes, consumed front to back. Once a sequence is
	// empty, further acquisitions of that kind report false.
	accountSeq, userSeq []bool

	// Call counters inspected by the tests.
	accountAcquireCalls, userAcquireCalls int
	accountReleaseCalls, userReleaseCalls int
}
|
|
|
|
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.accountAcquireCalls++
|
|
if len(s.accountSeq) == 0 {
|
|
return false, nil
|
|
}
|
|
v := s.accountSeq[0]
|
|
s.accountSeq = s.accountSeq[1:]
|
|
return v, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.accountReleaseCalls++
|
|
return nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
|
return true, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.userAcquireCalls++
|
|
if len(s.userSeq) == 0 {
|
|
return false, nil
|
|
}
|
|
v := s.userSeq[0]
|
|
s.userSeq = s.userSeq[1:]
|
|
return v, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.userReleaseCalls++
|
|
return nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
|
return true, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
|
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
|
for _, acc := range accounts {
|
|
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
|
out := make(map[int64]*service.UserLoadInfo, len(users))
|
|
for _, user := range users {
|
|
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
|
return nil
|
|
}
|
|
|
|
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
|
gin.SetMode(gin.TestMode)
|
|
rec := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(rec)
|
|
c.Request = httptest.NewRequest(method, path, nil)
|
|
return c, rec
|
|
}
|
|
|
|
// validClaudeCodeBodyJSON returns a /v1/messages request body that carries the
// Claude Code system prompt and a metadata.user_id of the form
// user_<hash>_account__session_<id>. Presumably this is the shape the strict
// Claude Code check looks for; confirm against SetClaudeCodeClientContext.
//
// The raw-string literal must contain only JSON. Any stray separator lines
// inside it would be part of the payload and make it unparseable.
func validClaudeCodeBodyJSON() []byte {
	return []byte(`{
		"model":"claude-3-5-sonnet-20241022",
		"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
		"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
	}`)
}
|
|
|
|
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
|
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
|
|
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
|
})
|
|
|
|
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
|
|
|
SetClaudeCodeClientContext(c, nil)
|
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
|
})
|
|
|
|
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
|
c.Request.Header.Set("X-App", "claude-code")
|
|
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
|
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
|
})
|
|
|
|
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
|
// 缺少严格校验所需 header + body 字段
|
|
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
|
})
|
|
}
|
|
|
|
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
|
cache := &helperConcurrencyCacheStub{
|
|
accountSeq: []bool{false, true},
|
|
userSeq: []bool{false, true},
|
|
}
|
|
concurrency := service.NewConcurrencyService(cache)
|
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
|
|
|
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, release)
|
|
require.False(t, streamStarted)
|
|
release()
|
|
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
|
|
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
|
|
})
|
|
|
|
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, release)
|
|
release()
|
|
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
|
|
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
|
|
})
|
|
}
|
|
|
|
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
|
cache := &helperConcurrencyCacheStub{
|
|
accountSeq: []bool{false, false, false},
|
|
}
|
|
concurrency := service.NewConcurrencyService(cache)
|
|
|
|
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
|
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
|
|
require.Nil(t, release)
|
|
var cErr *ConcurrencyError
|
|
require.ErrorAs(t, err, &cErr)
|
|
require.True(t, cErr.IsTimeout)
|
|
})
|
|
|
|
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
|
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
|
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
|
|
require.Nil(t, release)
|
|
var cErr *ConcurrencyError
|
|
require.ErrorAs(t, err, &cErr)
|
|
require.True(t, cErr.IsTimeout)
|
|
require.True(t, streamStarted)
|
|
require.Contains(t, rec.Body.String(), ":\n\n")
|
|
})
|
|
}
|
|
|
|
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
|
errCache := &helperConcurrencyCacheStubWithError{
|
|
err: errors.New("redis unavailable"),
|
|
}
|
|
concurrency := service.NewConcurrencyService(errCache)
|
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
|
|
require.Nil(t, release)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "redis unavailable")
|
|
}
|
|
|
|
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
|
|
cache := &helperConcurrencyCacheStub{
|
|
accountSeq: []bool{false},
|
|
}
|
|
concurrency := service.NewConcurrencyService(cache)
|
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
|
streamStarted := false
|
|
|
|
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
|
|
require.Nil(t, release)
|
|
var cErr *ConcurrencyError
|
|
require.ErrorAs(t, err, &cErr)
|
|
require.True(t, cErr.IsTimeout)
|
|
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
|
|
}
|
|
|
|
type helperConcurrencyCacheStubWithError struct {
|
|
helperConcurrencyCacheStub
|
|
err error
|
|
}
|
|
|
|
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
|
return false, s.err
|
|
}
|