Merge branch 'mt21625457/main'
This commit is contained in:
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
visited := map[int64]struct{}{}
|
||||
nextID := fallbackGroupID
|
||||
for {
|
||||
if _, seen := visited[nextID]; seen {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[nextID] = struct{}{}
|
||||
if currentGroupID > 0 && nextID == currentGroupID {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
if fallbackGroup.FallbackGroupID == nil {
|
||||
return nil
|
||||
}
|
||||
nextID = *fallbackGroup.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
|
||||
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
ctxGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
|
||||
groupID := int64(42)
|
||||
invalidGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
hydratedGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
|
||||
svc := &GatewayService{}
|
||||
ctx = svc.withGroupContext(ctx, hydratedGroup)
|
||||
|
||||
got, ok := ctx.Value(ctxkey.Group).(*Group)
|
||||
require.True(t, ok)
|
||||
require.Same(t, hydratedGroup, got)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{fallbackID: fallbackGroup},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &groupID,
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: group,
|
||||
fallbackID: fallbackGroup,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
|
||||
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gotGroup)
|
||||
require.Nil(t, gotID)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
groupID = resolvedGroupID
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
platform = group.Platform
|
||||
|
||||
// 检查 Claude Code 客户端限制
|
||||
if group.ClaudeCodeOnly {
|
||||
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||
if !isClaudeCode {
|
||||
// 非 Claude Code 客户端,检查是否有降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
// 使用降级分组重新调度
|
||||
fallbackGroupID := *group.FallbackGroupID
|
||||
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
// 无降级分组,拒绝访问
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -655,51 +642,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
|
||||
if !IsGroupContextValid(group) {
|
||||
return ctx
|
||||
}
|
||||
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, ctxkey.Group, group)
|
||||
}
|
||||
|
||||
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
|
||||
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
|
||||
return group
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
if group := s.groupFromContext(ctx, groupID); group != nil {
|
||||
return group, nil
|
||||
}
|
||||
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
currentID := *groupID
|
||||
visited := map[int64]struct{}{}
|
||||
for {
|
||||
if _, seen := visited[currentID]; seen {
|
||||
return nil, nil, fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[currentID] = struct{}{}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, currentID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
|
||||
return group, ¤tID, nil
|
||||
}
|
||||
|
||||
if group.FallbackGroupID == nil {
|
||||
return nil, nil, ErrClaudeCodeOnly
|
||||
}
|
||||
currentID = *group.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||
// - 有降级分组:返回降级分组的 ID
|
||||
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 分组启用了 Claude Code 限制
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 非 Claude Code 客户端,检查降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
return group.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
return nil, ErrClaudeCodeOnly
|
||||
return group, resolvedID, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, true, nil
|
||||
}
|
||||
if group != nil {
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
if groupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||
return "", false, err
|
||||
}
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
|
||||
@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -152,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformGemini,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {ID: groupID, Platform: PlatformGemini},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -10,6 +10,7 @@ type Group struct {
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
Hydrated bool // indicates the group was loaded from a trusted repository source
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if group.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if !group.Hydrated {
|
||||
return false
|
||||
}
|
||||
if group.Platform == "" || group.Status == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
GetByIDLite(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
Reference in New Issue
Block a user