diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index f6082e09..98ead284 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1122,6 +1122,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { SkipMixedChannelCheck: skipCheck, }) if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + c.JSON(409, gin.H{ + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + }) + return + } response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index 61b99e03..24ec5bcf 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -19,6 +19,7 @@ func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) router.POST("/api/v1/admin/accounts", accountHandler.Create) router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) + router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) return router } @@ -145,3 +146,53 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T require.False(t, hasDetails) require.False(t, hasRequireConfirmation) } + +func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2, 3}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "claude-max") +} + +func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2}, + "group_ids": []int64{27}, + "confirm_mixed_channel_risk": true, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), data["success"]) + require.Equal(t, float64(0), data["failed"]) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 1d469bd7..f3b99ddb 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -10,22 +10,23 @@ import ( ) type stubAdminService struct { - users []service.User - apiKeys []service.APIKey - groups []service.Group - accounts []service.Account - proxies []service.Proxy - proxyCounts []service.ProxyWithAccountCount - redeems []service.RedeemCode - createdAccounts []*service.CreateAccountInput - createdProxies []*service.CreateProxyInput - updatedProxyIDs []int64 - updatedProxies []*service.UpdateProxyInput - testedProxyIDs []int64 - createAccountErr error - updateAccountErr error - checkMixedErr error - lastMixedCheck struct { + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode + createdAccounts []*service.CreateAccountInput + createdProxies []*service.CreateProxyInput + updatedProxyIDs []int64 + updatedProxies []*service.UpdateProxyInput + testedProxyIDs []int64 + createAccountErr error + updateAccountErr error + bulkUpdateAccountErr error + checkMixedErr error + lastMixedCheck struct { accountID int64 platform string groupIDs []int64 @@ -235,7 +236,10 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, } func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { - return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil + if s.bulkUpdateAccountErr != nil { + return nil, s.bulkUpdateAccountErr + } + return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil } func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index ee78b6d3..f9995d04 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1539,30 +1539,31 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck - // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 + // 预加载账号平台信息(混合渠道检查需要)。 platformByID := map[int64]string{} - groupAccountsByID := map[int64][]Account{} - groupNameByID := map[int64]string{} if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) - if err != nil { - if needMixedChannelCheck { - return nil, err - } - } else { - for _, account := range accounts { - if account != nil { - platformByID[account.ID] = account.Platform - } - } - } - - loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs) if err != nil { return nil, err } - groupAccountsByID = loadedAccounts - groupNameByID = loadedNames + for _, account := range accounts { + if account != nil { + platformByID[account.ID] = account.Platform + } + } + } + + // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。 + if needMixedChannelCheck { + for _, accountID := range input.AccountIDs { + platform := platformByID[accountID] + if platform == "" { + continue + } + if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + return nil, err + } + } } if input.RateMultiplier != nil { @@ -1606,34 +1607,8 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp // Handle group bindings per account (requires individual operations). for _, accountID := range input.AccountIDs { entry := BulkUpdateAccountResult{AccountID: accountID} - platform := "" if input.GroupIDs != nil { - // 检查混合渠道风险(除非用户已确认) - if !input.SkipMixedChannelCheck { - platform = platformByID[accountID] - if platform == "" { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - platform = account.Platform - } - if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - } - if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() @@ -1642,9 +1617,6 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Results = append(result.Results, entry) continue } - if !input.SkipMixedChannelCheck && platform != "" { - updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform) - } } entry.Success = true @@ -2316,41 +2288,6 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } -func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) { - accountsByGroup := make(map[int64][]Account) - groupNameByID := make(map[int64]string) - if len(groupIDs) == 0 { - return accountsByGroup, groupNameByID, nil - } - - seen := make(map[int64]struct{}, len(groupIDs)) - for _, groupID := range groupIDs { - if groupID <= 0 { - continue - } - if _, ok := seen[groupID]; ok { - continue - } - seen[groupID] = struct{}{} - - accounts, err := s.accountRepo.ListByGroup(ctx, groupID) - if err != nil { - return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err) - } - accountsByGroup[groupID] = accounts - - group, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - continue - } - if group != nil { - groupNameByID[groupID] = group.Name - } - } - - return accountsByGroup, groupNameByID, nil -} - func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { if len(groupIDs) == 0 { return nil @@ -2380,71 +2317,6 @@ func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs [ return nil } -func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error { - currentPlatform := getAccountPlatform(currentAccountPlatform) - if currentPlatform == "" { - return nil - } - - for _, groupID := range groupIDs { - accounts := accountsByGroup[groupID] - for _, account := range accounts { - if currentAccountID > 0 && account.ID == currentAccountID { - continue - } - - otherPlatform := getAccountPlatform(account.Platform) - if otherPlatform == "" { - continue - } - - if currentPlatform != otherPlatform { - groupName := fmt.Sprintf("Group %d", groupID) - if name := strings.TrimSpace(groupNameByID[groupID]); name != "" { - groupName = name - } - - return &MixedChannelError{ - GroupID: groupID, - GroupName: groupName, - CurrentPlatform: currentPlatform, - OtherPlatform: otherPlatform, - } - } - } - } - - return nil -} - -func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) { - if len(groupIDs) == 0 || accountID <= 0 || platform == "" { - return - } - for _, groupID := range groupIDs { - if groupID <= 0 { - continue - } - accounts := accountsByGroup[groupID] - found := false - for i := range accounts { - if accounts[i].ID != accountID { - continue - } - accounts[i].Platform = platform - found = true - break - } - if !found { - accounts = append(accounts, Account{ - ID: accountID, - Platform: platform, - }) - } - accountsByGroup[groupID] = accounts - } -} - // CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 647a84a9..4845d87c 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -139,34 +139,34 @@ func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) require.Contains(t, err.Error(), "group repository not configured") } -func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) { +// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies +// that the global pre-check detects a conflict with existing group members and returns an +// error before any DB write is performed. +func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) { repo := &accountRepoStubForBulkUpdate{ getByIDsAccounts: []*Account{ - {ID: 1, Platform: PlatformAnthropic}, - {ID: 2, Platform: PlatformAntigravity}, + {ID: 1, Platform: PlatformAntigravity}, }, + // Group 10 already contains an Anthropic account. listByGroupData: map[int64][]Account{ - 10: {}, + 10: {{ID: 99, Platform: PlatformAnthropic}}, }, } svc := &adminServiceImpl{ accountRepo: repo, - groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}}, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}}, } groupIDs := []int64{10} input := &BulkUpdateAccountsInput{ - AccountIDs: []int64{1, 2}, + AccountIDs: []int64{1}, GroupIDs: &groupIDs, } result, err := svc.BulkUpdateAccounts(context.Background(), input) - require.NoError(t, err) - require.Equal(t, 1, result.Success) - require.Equal(t, 1, result.Failed) - require.ElementsMatch(t, []int64{1}, result.SuccessIDs) - require.ElementsMatch(t, []int64{2}, result.FailedIDs) - require.Len(t, result.Results, 2) - require.Contains(t, result.Results[1].Error, "mixed channel") - require.Equal(t, []int64{1}, repo.bindGroupsCalls) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "mixed channel") + // No BindGroups should have been called since the check runs before any write. + require.Empty(t, repo.bindGroupsCalls) } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 22db5a44..95f9ff31 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -267,6 +267,7 @@ apiClient.interceptors.response.use( return Promise.reject({ status, code: apiData.code, + error: apiData.error, message: apiData.message || apiData.detail || error.message }) } diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index ae16ff1a..30c3d739 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -756,6 +756,17 @@ + +