Merge pull request #1406 from DaydreamCoding/feat/group-account-filter
feat(group-filter): 分组账号过滤控制 — require_oauth_only + require_privacy_set
This commit is contained in:
@@ -141,6 +141,21 @@ func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
// IsPrivacySet 检查账号的 privacy 是否已成功设置。
|
||||
// OpenAI: privacy_mode == "training_off"
|
||||
// Antigravity: privacy_mode == "privacy_set"
|
||||
// 其他平台: 无 privacy 概念,始终返回 true
|
||||
func (a *Account) IsPrivacySet() bool {
|
||||
switch a.Platform {
|
||||
case PlatformOpenAI:
|
||||
return a.getExtraString("privacy_mode") == PrivacyModeTrainingOff
|
||||
case PlatformAntigravity:
|
||||
return a.getExtraString("privacy_mode") == AntigravityPrivacySet
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
@@ -174,6 +174,19 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
return nil, fmt.Errorf("create account: %w", err)
|
||||
}
|
||||
|
||||
// require_oauth_only 检查:apikey 类型账号不可加入限制分组
|
||||
if account.Type == AccountTypeAPIKey && len(req.GroupIDs) > 0 {
|
||||
for _, gid := range req.GroupIDs {
|
||||
g, err := s.groupRepo.GetByID(ctx, gid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(req.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
|
||||
@@ -277,6 +290,19 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// require_oauth_only 检查
|
||||
if account.Type == AccountTypeAPIKey && req.GroupIDs != nil {
|
||||
for _, gid := range *req.GroupIDs {
|
||||
g, err := s.groupRepo.GetByID(ctx, gid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) {
|
||||
return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if req.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
|
||||
|
||||
@@ -163,6 +163,8 @@ type CreateGroupInput struct {
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
DefaultMappedModel string
|
||||
RequireOAuthOnly bool
|
||||
RequirePrivacySet bool
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -202,6 +204,8 @@ type UpdateGroupInput struct {
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch *bool
|
||||
DefaultMappedModel *string
|
||||
RequireOAuthOnly *bool
|
||||
RequirePrivacySet *bool
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -942,12 +946,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: input.RequireOAuthOnly,
|
||||
RequirePrivacySet: input.RequirePrivacySet,
|
||||
DefaultMappedModel: input.DefaultMappedModel,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
}
|
||||
oauthIDs := make(map[int64]struct{}, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Type != AccountTypeAPIKey {
|
||||
oauthIDs[acc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
var filtered []int64
|
||||
for _, aid := range accountIDsToCopy {
|
||||
if _, ok := oauthIDs[aid]; ok {
|
||||
filtered = append(filtered, aid)
|
||||
}
|
||||
}
|
||||
accountIDsToCopy = filtered
|
||||
}
|
||||
|
||||
// 如果有需要复制的账号,绑定到新分组
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
|
||||
@@ -1155,6 +1182,12 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.AllowMessagesDispatch != nil {
|
||||
group.AllowMessagesDispatch = *input.AllowMessagesDispatch
|
||||
}
|
||||
if input.RequireOAuthOnly != nil {
|
||||
group.RequireOAuthOnly = *input.RequireOAuthOnly
|
||||
}
|
||||
if input.RequirePrivacySet != nil {
|
||||
group.RequirePrivacySet = *input.RequirePrivacySet
|
||||
}
|
||||
if input.DefaultMappedModel != nil {
|
||||
group.DefaultMappedModel = *input.DefaultMappedModel
|
||||
}
|
||||
@@ -1202,6 +1235,27 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
|
||||
}
|
||||
|
||||
// require_oauth_only: 过滤掉 apikey 类型账号
|
||||
if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err)
|
||||
}
|
||||
oauthIDs := make(map[int64]struct{}, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Type != AccountTypeAPIKey {
|
||||
oauthIDs[acc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
var filtered []int64
|
||||
for _, aid := range accountIDsToCopy {
|
||||
if _, ok := oauthIDs[aid]; ok {
|
||||
filtered = append(filtered, aid)
|
||||
}
|
||||
}
|
||||
accountIDsToCopy = filtered
|
||||
}
|
||||
|
||||
// 再绑定源分组的账号
|
||||
if len(accountIDsToCopy) > 0 {
|
||||
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
|
||||
|
||||
@@ -3139,7 +3139,7 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
@@ -3182,7 +3182,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
@@ -3252,7 +3252,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
|
||||
@@ -2744,6 +2744,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -2815,6 +2821,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2917,6 +2929,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2980,6 +2998,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -3047,6 +3071,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@@ -3151,6 +3181,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
|
||||
@@ -59,6 +59,8 @@ type Group struct {
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool
|
||||
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
|
||||
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
|
||||
DefaultMappedModel string
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sort"
|
||||
@@ -575,6 +576,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if req.GroupID != nil && s.service.schedulerSnapshot != nil {
|
||||
schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID)
|
||||
}
|
||||
|
||||
filtered := make([]*Account, 0, len(accounts))
|
||||
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
@@ -587,6 +594,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
|
||||
_ = s.service.accountRepo.SetError(ctx, account.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -152,6 +152,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int
|
||||
return s.accountRepo.GetByID(fallbackCtx, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByID 获取分组信息(供调度器使用)
|
||||
func (s *SchedulerSnapshotService) GetGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
if s.groupRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.groupRepo.GetByID(ctx, groupID)
|
||||
}
|
||||
|
||||
// UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效)
|
||||
func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error {
|
||||
if s.cache == nil || account == nil {
|
||||
|
||||
Reference in New Issue
Block a user