Merge pull request #690 from touwaeriol/pr/bulk-edit-mixed-channel-warning

feat: add mixed-channel warning for bulk account edit
This commit is contained in:
Wesley Liddick
2026-03-01 18:25:05 +08:00
committed by GitHub
8 changed files with 213 additions and 191 deletions

View File

@@ -1122,6 +1122,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
})
return
}
response.ErrorFrom(c, err)
return
}

View File

@@ -19,6 +19,7 @@ func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
@@ -145,3 +146,53 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2, 3},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
}
func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2},
"group_ids": []int64{27},
"confirm_mixed_channel_risk": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}

View File

@@ -10,22 +10,23 @@ import (
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
checkMixedErr error
lastMixedCheck struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
@@ -235,7 +236,10 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
if s.bulkUpdateAccountErr != nil {
return nil, s.bulkUpdateAccountErr
}
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {

View File

@@ -1539,30 +1539,31 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
// 预加载账号平台信息(混合渠道检查需要)。
platformByID := map[int64]string{}
groupAccountsByID := map[int64][]Account{}
groupNameByID := map[int64]string{}
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
if needMixedChannelCheck {
return nil, err
}
} else {
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
if err != nil {
return nil, err
}
groupAccountsByID = loadedAccounts
groupNameByID = loadedNames
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。
if needMixedChannelCheck {
for _, accountID := range input.AccountIDs {
platform := platformByID[accountID]
if platform == "" {
continue
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
return nil, err
}
}
}
if input.RateMultiplier != nil {
@@ -1606,34 +1607,8 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
platform := ""
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform = platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1642,9 +1617,6 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
if !input.SkipMixedChannelCheck && platform != "" {
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
}
}
entry.Success = true
@@ -2316,41 +2288,6 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
accountsByGroup := make(map[int64][]Account)
groupNameByID := make(map[int64]string)
if len(groupIDs) == 0 {
return accountsByGroup, groupNameByID, nil
}
seen := make(map[int64]struct{}, len(groupIDs))
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
if _, ok := seen[groupID]; ok {
continue
}
seen[groupID] = struct{}{}
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
accountsByGroup[groupID] = accounts
group, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
continue
}
if group != nil {
groupNameByID[groupID] = group.Name
}
}
return accountsByGroup, groupNameByID, nil
}
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
@@ -2380,71 +2317,6 @@ func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs [
return nil
}
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
return nil
}
for _, groupID := range groupIDs {
accounts := accountsByGroup[groupID]
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue
}
if currentPlatform != otherPlatform {
groupName := fmt.Sprintf("Group %d", groupID)
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
groupName = name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
return
}
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
accounts := accountsByGroup[groupID]
found := false
for i := range accounts {
if accounts[i].ID != accountID {
continue
}
accounts[i].Platform = platform
found = true
break
}
if !found {
accounts = append(accounts, Account{
ID: accountID,
Platform: platform,
})
}
accountsByGroup[groupID] = accounts
}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)

View File

@@ -139,34 +139,34 @@ func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T)
require.Contains(t, err.Error(), "group repository not configured")
}
func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
// that the global pre-check detects a conflict with existing group members and returns an
// error before any DB write is performed.
func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformAnthropic},
{ID: 2, Platform: PlatformAntigravity},
{ID: 1, Platform: PlatformAntigravity},
},
// Group 10 already contains an Anthropic account.
listByGroupData: map[int64][]Account{
10: {},
10: {{ID: 99, Platform: PlatformAnthropic}},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2},
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 1, result.Success)
require.Equal(t, 1, result.Failed)
require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 2)
require.Contains(t, result.Results[1].Error, "mixed channel")
require.Equal(t, []int64{1}, repo.bindGroupsCalls)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "mixed channel")
// No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls)
}