Merge branch 'test' into dev

This commit is contained in:
yangjianbo
2026-01-10 09:39:02 +08:00
18 changed files with 805 additions and 61 deletions

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func ptr[T any](v T) *T {
return &v
}
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Contains(t, err.Error(), "no available accounts")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
ctxGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
groupID := int64(42)
invalidGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
hydratedGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
svc := &GatewayService{}
ctx = svc.withGroupContext(ctx, hydratedGroup)
got, ok := ctx.Value(ctxkey.Group).(*Group)
require.True(t, ok)
require.Same(t, hydratedGroup, got)
}
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{fallbackID: fallbackGroup},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &groupID,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: group,
fallbackID: fallbackGroup,
},
}
svc := &GatewayService{
groupRepo: groupRepo,
}
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
require.Error(t, err)
require.Nil(t, gotGroup)
require.Nil(t, gotID)
require.Contains(t, err.Error(), "fallback group cycle")
}