From 2597fe78ba2d42384f95c8f7982d0078d77e824a Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 07:56:50 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E5=88=86=E7=BB=84):=20=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E9=99=8D=E7=BA=A7=E7=8E=AF=E5=B9=B6=E6=A0=A1=E9=AA=8C=E4=B8=8A?= =?UTF-8?q?=E4=B8=8B=E6=96=87=E5=88=86=E7=BB=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 增加降级链路环检测并拦截配置 - 仅复用合法分组上下文并必要时回退查询 - 标注 GetByIDLite 轻量语义并补充测试 --- backend/internal/repository/group_repo.go | 1 + .../server/middleware/api_key_auth.go | 2 +- backend/internal/service/admin_service.go | 35 +++++--- .../service/admin_service_group_test.go | 81 +++++++++++++++++++ .../service/gateway_multiplatform_test.go | 79 ++++++++++++++++++ backend/internal/service/gateway_service.go | 10 ++- .../service/gemini_messages_compat_service.go | 2 +- backend/internal/service/group.go | 14 ++++ 8 files changed, 210 insertions(+), 14 deletions(-) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index daff8b89..afb41a93 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -70,6 +70,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group } func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) { + // AccountCount is intentionally not loaded here; use GetByID when needed. m, err := r.client.Group.Query(). Where(group.IDEQ(id)). Only(ctx) diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 8d78e32d..bb4f549a 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -179,7 +179,7 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool } func setGroupContext(c *gin.Context, group *service.Group) { - if group == nil { + if !service.IsGroupContextValid(group) { return } if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index e29bbdb4..e6e171d0 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -575,18 +575,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro return fmt.Errorf("cannot set self as fallback group") } - // 检查降级分组是否存在 - fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) - if err != nil { - return fmt.Errorf("fallback group not found: %w", err) - } + visited := map[int64]struct{}{} + nextID := fallbackGroupID + for { + if _, seen := visited[nextID]; seen { + return fmt.Errorf("fallback group cycle detected") + } + visited[nextID] = struct{}{} + if currentGroupID > 0 && nextID == currentGroupID { + return fmt.Errorf("fallback group cycle detected") + } - // 降级分组不能启用 claude_code_only,否则会造成死循环 - if fallbackGroup.ClaudeCodeOnly { - return fmt.Errorf("fallback group cannot have claude_code_only enabled") - } + // 检查降级分组是否存在 + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } - return nil + // 降级分组不能启用 claude_code_only,否则会造成死循环 + if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly { + return fmt.Errorf("fallback group cannot have claude_code_only enabled") + } + + if fallbackGroup.FallbackGroupID == nil { + return nil + } + nextID = *fallbackGroup.FallbackGroupID + } } func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 675f4c6f..4e956dae 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -202,3 +202,84 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.Nil(t, repo.updated.ImagePrice4K) } + +func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) { + groupID := int64(1) + fallbackID := int64(2) + repo := &groupRepoStubForFallbackCycle{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + FallbackGroupID: &fallbackID, + }, + fallbackID: { + ID: fallbackID, + FallbackGroupID: &groupID, + }, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group cycle") +} + +type groupRepoStubForFallbackCycle struct { + groups map[int64]*Group +} + +func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error { + panic("unexpected Create call") +} + +func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error { + panic("unexpected Update call") +} + +func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) { + return s.GetByIDLite(ctx, id) +} + +func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) { + if g, ok := s.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 4f6545e2..0b5729fe 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1102,6 +1102,47 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { require.Equal(t, 0, groupRepo.getByIDLiteCalls) } +func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) { + ctx := context.Background() + groupID := int64(42) + ctxGroup := &Group{ + ID: groupID, + Status: StatusActive, + } + ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + group := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{groupID: group}, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cfg: testConfig(), + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDLiteCalls) +} + func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ctx := context.Background() groupID := int64(10) @@ -1146,3 +1187,41 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { require.Equal(t, 0, groupRepo.getByIDCalls) require.Equal(t, 1, groupRepo.getByIDLiteCalls) } + +func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + fallbackID := int64(11) + + group := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + } + fallbackGroup := &Group{ + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &groupID, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: group, + fallbackID: fallbackGroup, + }, + } + + svc := &GatewayService{ + groupRepo: groupRepo, + } + + gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID) + require.Error(t, err) + require.Nil(t, gotGroup) + require.Nil(t, gotID) + require.Contains(t, err.Error(), "fallback group cycle") +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 27353022..3ab42c3e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -640,7 +640,7 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { } func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { - if group == nil { + if !IsGroupContextValid(group) { return ctx } if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID { @@ -650,7 +650,7 @@ func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) con } func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { - if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil && group.ID == groupID { + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID { return group } return nil @@ -673,7 +673,13 @@ func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64 } currentID := *groupID + visited := map[int64]struct{}{} for { + if _, seen := visited[currentID]; seen { + return nil, nil, fmt.Errorf("fallback group cycle detected") + } + visited[currentID] = struct{}{} + group, err := s.resolveGroupByID(ctx, currentID) if err != nil { return nil, nil, err diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index fba50b62..db755a34 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -87,7 +87,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } else if groupID != nil { // 根据分组 platform 决定查询哪种账号 var group *Group - if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && ctxGroup != nil && ctxGroup.ID == *groupID { + if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { group = ctxGroup } else { var err error diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 80d89074..e148ea00 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -72,3 +72,17 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { return g.ImagePrice2K } } + +// IsGroupContextValid reports whether a group from context has the fields required for routing decisions. +func IsGroupContextValid(group *Group) bool { + if group == nil { + return false + } + if group.ID <= 0 { + return false + } + if group.Platform == "" || group.Status == "" { + return false + } + return true +}