fix(分组): 防止降级环并校验上下文分组
- 增加降级链路环检测并拦截配置 - 仅复用合法分组上下文并必要时回退查询 - 标注 GetByIDLite 轻量语义并补充测试
This commit is contained in:
@@ -70,6 +70,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||||
|
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||||
m, err := r.client.Group.Query().
|
m, err := r.client.Group.Query().
|
||||||
Where(group.IDEQ(id)).
|
Where(group.IDEQ(id)).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||||
if group == nil {
|
if !service.IsGroupContextValid(group) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID {
|
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID {
|
||||||
|
|||||||
@@ -575,18 +575,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
|||||||
return fmt.Errorf("cannot set self as fallback group")
|
return fmt.Errorf("cannot set self as fallback group")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查降级分组是否存在
|
visited := map[int64]struct{}{}
|
||||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
nextID := fallbackGroupID
|
||||||
if err != nil {
|
for {
|
||||||
return fmt.Errorf("fallback group not found: %w", err)
|
if _, seen := visited[nextID]; seen {
|
||||||
}
|
return fmt.Errorf("fallback group cycle detected")
|
||||||
|
}
|
||||||
|
visited[nextID] = struct{}{}
|
||||||
|
if currentGroupID > 0 && nextID == currentGroupID {
|
||||||
|
return fmt.Errorf("fallback group cycle detected")
|
||||||
|
}
|
||||||
|
|
||||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
// 检查降级分组是否存在
|
||||||
if fallbackGroup.ClaudeCodeOnly {
|
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
|
||||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("fallback group not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||||
|
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
|
||||||
|
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fallbackGroup.FallbackGroupID == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
nextID = *fallbackGroup.FallbackGroupID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||||
|
|||||||
@@ -202,3 +202,84 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
|||||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||||
require.Nil(t, repo.updated.ImagePrice4K)
|
require.Nil(t, repo.updated.ImagePrice4K)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
fallbackID := int64(2)
|
||||||
|
repo := &groupRepoStubForFallbackCycle{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
groupID: {
|
||||||
|
ID: groupID,
|
||||||
|
FallbackGroupID: &fallbackID,
|
||||||
|
},
|
||||||
|
fallbackID: {
|
||||||
|
ID: fallbackID,
|
||||||
|
FallbackGroupID: &groupID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "fallback group cycle")
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupRepoStubForFallbackCycle struct {
|
||||||
|
groups map[int64]*Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||||
|
return s.GetByIDLite(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||||
|
if g, ok := s.groups[id]; ok {
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
return nil, ErrGroupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||||
|
panic("unexpected DeleteCascade call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected List call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListWithFilters call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||||
|
panic("unexpected ListActive call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||||
|
panic("unexpected ListActiveByPlatform call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByName call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected GetAccountCount call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1102,6 +1102,47 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
|||||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(42)
|
||||||
|
ctxGroup := &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
group := &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
groupRepo := &mockGroupRepoForGateway{
|
||||||
|
groups: map[int64]*Group{groupID: group},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cfg: testConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||||
|
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
groupID := int64(10)
|
groupID := int64(10)
|
||||||
@@ -1146,3 +1187,41 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
|||||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(10)
|
||||||
|
fallbackID := int64(11)
|
||||||
|
|
||||||
|
group := &Group{
|
||||||
|
ID: groupID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
ClaudeCodeOnly: true,
|
||||||
|
FallbackGroupID: &fallbackID,
|
||||||
|
}
|
||||||
|
fallbackGroup := &Group{
|
||||||
|
ID: fallbackID,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
ClaudeCodeOnly: true,
|
||||||
|
FallbackGroupID: &groupID,
|
||||||
|
}
|
||||||
|
|
||||||
|
groupRepo := &mockGroupRepoForGateway{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
groupID: group,
|
||||||
|
fallbackID: fallbackGroup,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
}
|
||||||
|
|
||||||
|
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, gotGroup)
|
||||||
|
require.Nil(t, gotID)
|
||||||
|
require.Contains(t, err.Error(), "fallback group cycle")
|
||||||
|
}
|
||||||
|
|||||||
@@ -640,7 +640,7 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
|
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
|
||||||
if group == nil {
|
if !IsGroupContextValid(group) {
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID {
|
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID {
|
||||||
@@ -650,7 +650,7 @@ func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) con
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
|
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
|
||||||
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil && group.ID == groupID {
|
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
|
||||||
return group
|
return group
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -673,7 +673,13 @@ func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
currentID := *groupID
|
currentID := *groupID
|
||||||
|
visited := map[int64]struct{}{}
|
||||||
for {
|
for {
|
||||||
|
if _, seen := visited[currentID]; seen {
|
||||||
|
return nil, nil, fmt.Errorf("fallback group cycle detected")
|
||||||
|
}
|
||||||
|
visited[currentID] = struct{}{}
|
||||||
|
|
||||||
group, err := s.resolveGroupByID(ctx, currentID)
|
group, err := s.resolveGroupByID(ctx, currentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
} else if groupID != nil {
|
} else if groupID != nil {
|
||||||
// 根据分组 platform 决定查询哪种账号
|
// 根据分组 platform 决定查询哪种账号
|
||||||
var group *Group
|
var group *Group
|
||||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && ctxGroup != nil && ctxGroup.ID == *groupID {
|
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||||
group = ctxGroup
|
group = ctxGroup
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
@@ -72,3 +72,17 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
|||||||
return g.ImagePrice2K
|
return g.ImagePrice2K
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||||
|
func IsGroupContextValid(group *Group) bool {
|
||||||
|
if group == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if group.ID <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if group.Platform == "" || group.Status == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user