test: add warmup request interception unit tests
Add comprehensive tests for warmup request interception behavior covering Antigravity accounts with various credential configurations.
This commit is contained in:
@@ -0,0 +1,340 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
|
||||
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
|
||||
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
|
||||
|
||||
type fakeSchedulerCache struct {
|
||||
accounts []*service.Account
|
||||
}
|
||||
|
||||
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return f.accounts, true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
|
||||
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
|
||||
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
|
||||
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
|
||||
|
||||
type fakeGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
|
||||
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeConcurrencyCache struct{}
|
||||
|
||||
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
|
||||
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
|
||||
schedulerCache := &fakeSchedulerCache{accounts: accounts}
|
||||
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
|
||||
|
||||
gwSvc := service.NewGatewayService(
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache (disable sticky)
|
||||
nil, // cfg
|
||||
schedulerSnapshot,
|
||||
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
|
||||
nil, // billingService
|
||||
nil, // rateLimitService
|
||||
nil, // billingCacheService
|
||||
nil, // identityService
|
||||
nil, // httpUpstream
|
||||
nil, // deferredService
|
||||
nil, // claudeTokenProvider
|
||||
nil, // sessionLimitCache
|
||||
nil, // digestStore
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
h := &GatewayHandler{
|
||||
gatewayService: gwSvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
concurrencyHelper: concurrencyHelper,
|
||||
// 这些字段对本测试不敏感,保持较小即可
|
||||
maxAccountSwitches: 1,
|
||||
maxAccountSwitchesGemini: 1,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
billingCacheSvc.Stop()
|
||||
}
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2001)
|
||||
accountID := int64(1001)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
|
||||
c.Request = req
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3001,
|
||||
UserID: 4001,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4001,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
|
||||
content, ok := resp["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "New Conversation", first["text"])
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2002)
|
||||
accountID := int64(1002)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-2",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
|
||||
// - 写入 request.Context(Service读取)
|
||||
// - 写入 gin.Context(Handler快速读取)
|
||||
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
|
||||
req = req.WithContext(ctx)
|
||||
c.Request = req
|
||||
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3002,
|
||||
UserID: 4002,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4002,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
}
|
||||
Reference in New Issue
Block a user