Merge pull request #1406 from DaydreamCoding/feat/group-account-filter

feat(group-filter): 分组账号过滤控制 — require_oauth_only + require_privacy_set
This commit is contained in:
Wesley Liddick
2026-03-31 14:01:05 +08:00
committed by GitHub
26 changed files with 708 additions and 6 deletions

View File

@@ -141,6 +141,21 @@ func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// IsPrivacySet 检查账号的 privacy 是否已成功设置。
// OpenAI: privacy_mode == "training_off"
// Antigravity: privacy_mode == "privacy_set"
// 其他平台: 无 privacy 概念,始终返回 true
func (a *Account) IsPrivacySet() bool {
switch a.Platform {
case PlatformOpenAI:
return a.getExtraString("privacy_mode") == PrivacyModeTrainingOff
case PlatformAntigravity:
return a.getExtraString("privacy_mode") == AntigravityPrivacySet
default:
return true
}
}
func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}

View File

@@ -174,6 +174,19 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
return nil, fmt.Errorf("create account: %w", err)
}
// require_oauth_only 检查apikey 类型账号不可加入限制分组
if account.Type == AccountTypeAPIKey && len(req.GroupIDs) > 0 {
for _, gid := range req.GroupIDs {
g, err := s.groupRepo.GetByID(ctx, gid)
if err != nil {
return nil, err
}
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号apikey 类型账号无法加入", g.Name)
}
}
}
// 绑定分组
if len(req.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
@@ -277,6 +290,19 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
return nil, fmt.Errorf("update account: %w", err)
}
// require_oauth_only 检查
if account.Type == AccountTypeAPIKey && req.GroupIDs != nil {
for _, gid := range *req.GroupIDs {
g, err := s.groupRepo.GetByID(ctx, gid)
if err != nil {
return nil, err
}
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号apikey 类型账号无法加入", g.Name)
}
}
}
// 绑定分组
if req.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {

View File

@@ -163,6 +163,8 @@ type CreateGroupInput struct {
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
RequireOAuthOnly bool
RequirePrivacySet bool
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -202,6 +204,8 @@ type UpdateGroupInput struct {
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
RequireOAuthOnly *bool
RequirePrivacySet *bool
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -942,12 +946,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
// require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
}
oauthIDs := make(map[int64]struct{}, len(accounts))
for _, acc := range accounts {
if acc.Type != AccountTypeAPIKey {
oauthIDs[acc.ID] = struct{}{}
}
}
var filtered []int64
for _, aid := range accountIDsToCopy {
if _, ok := oauthIDs[aid]; ok {
filtered = append(filtered, aid)
}
}
accountIDsToCopy = filtered
}
// 如果有需要复制的账号,绑定到新分组
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
@@ -1155,6 +1182,12 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.AllowMessagesDispatch != nil {
group.AllowMessagesDispatch = *input.AllowMessagesDispatch
}
if input.RequireOAuthOnly != nil {
group.RequireOAuthOnly = *input.RequireOAuthOnly
}
if input.RequirePrivacySet != nil {
group.RequirePrivacySet = *input.RequirePrivacySet
}
if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel
}
@@ -1202,6 +1235,27 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
}
// require_oauth_only: 过滤掉 apikey 类型账号
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
if err != nil {
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
}
oauthIDs := make(map[int64]struct{}, len(accounts))
for _, acc := range accounts {
if acc.Type != AccountTypeAPIKey {
oauthIDs[acc.ID] = struct{}{}
}
}
var filtered []int64
for _, aid := range accountIDsToCopy {
if _, ok := oauthIDs[aid]; ok {
filtered = append(filtered, aid)
}
}
accountIDsToCopy = filtered
}
// 再绑定源分组的账号
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {

View File

@@ -3139,7 +3139,7 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
@@ -3182,7 +3182,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
@@ -3252,7 +3252,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}

View File

@@ -2744,6 +2744,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -2815,6 +2821,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2917,6 +2929,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2980,6 +2998,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -3047,6 +3071,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3151,6 +3181,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue

View File

@@ -59,6 +59,8 @@ type Group struct {
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联OpenAI/Antigravity/Anthropic/Gemini
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号OpenAI/Antigravity/Anthropic/Gemini
DefaultMappedModel string
CreatedAt time.Time

View File

@@ -4,6 +4,7 @@ import (
"container/heap"
"context"
"errors"
"fmt"
"hash/fnv"
"math"
"sort"
@@ -575,6 +576,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// require_privacy_set: 获取分组信息
var schedGroup *Group
if req.GroupID != nil && s.service.schedulerSnapshot != nil {
schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID)
}
filtered := make([]*Account, 0, len(accounts))
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
for i := range accounts {
@@ -587,6 +594,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
_ = s.service.accountRepo.SetError(ctx, account.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
continue
}

View File

@@ -152,6 +152,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int
return s.accountRepo.GetByID(fallbackCtx, accountID)
}
// GetGroupByID 获取分组信息(供调度器使用)
func (s *SchedulerSnapshotService) GetGroupByID(ctx context.Context, groupID int64) (*Group, error) {
if s.groupRepo == nil {
return nil, nil
}
return s.groupRepo.GetByID(ctx, groupID)
}
// UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效)
func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error {
if s.cache == nil || account == nil {