Merge pull request #457 from touwaeriol/pr/group-copy-accounts
feat(groups): 添加从其他分组复制账号功能
This commit is contained in:
@@ -43,6 +43,8 @@ type CreateGroupRequest struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -66,6 +68,8 @@ type UpdateGroupRequest struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -155,22 +159,23 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -196,23 +201,24 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -425,3 +425,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var accountIDs []int64
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
if err := rows.Scan(&accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accountIDs = append(accountIDs, accountID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
}
|
||||
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
|
||||
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
if len(accountIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
|
||||
_, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
|
||||
SELECT unnest($1::bigint[]), $2, 50, NOW()
|
||||
ON CONFLICT (account_id, group_id) DO NOTHING`,
|
||||
pq.Array(accountIDs),
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -880,6 +880,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
bulkUpdateIDs []int64
|
||||
}
|
||||
|
||||
@@ -110,6 +110,8 @@ type CreateGroupInput struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -132,6 +134,8 @@ type UpdateGroupInput struct {
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -572,6 +576,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与新分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
var err error
|
||||
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -593,6 +629,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果有需要复制的账号,绑定到新分组
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
|
||||
}
|
||||
group.AccountCount = int64(len(accountIDsToCopy))
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -728,6 +773,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
|
||||
if len(input.CopyAccountsFromGroupIDs) > 0 {
|
||||
// 去重源分组 IDs
|
||||
seen := make(map[int64]struct{})
|
||||
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
|
||||
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
|
||||
// 校验:源分组不能是自身
|
||||
if srcGroupID == id {
|
||||
return nil, fmt.Errorf("cannot copy accounts from self")
|
||||
}
|
||||
// 去重
|
||||
if _, exists := seen[srcGroupID]; !exists {
|
||||
seen[srcGroupID] = struct{}{}
|
||||
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// 校验源分组的平台是否与当前分组一致
|
||||
for _, srcGroupID := range uniqueSourceGroupIDs {
|
||||
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
|
||||
}
|
||||
if srcGroup.Platform != group.Platform {
|
||||
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有源分组的账号(去重)
|
||||
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
|
||||
}
|
||||
|
||||
// 先清空当前分组的所有账号绑定
|
||||
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
|
||||
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
|
||||
}
|
||||
|
||||
// 再绑定源分组的账号
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
|
||||
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
|
||||
@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
@@ -378,3 +386,11 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
|
||||
@@ -29,6 +29,10 @@ type GroupRepository interface {
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组
|
||||
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
|
||||
Reference in New Issue
Block a user