feat(sync): full code sync from release
This commit is contained in:
@@ -83,13 +83,14 @@ type AdminService interface {
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
SoraStorageQuotaBytes int64
|
||||
}
|
||||
|
||||
type UpdateUserInput struct {
|
||||
@@ -103,7 +104,8 @@ type UpdateUserInput struct {
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
GroupRates map[int64]*float64
|
||||
SoraStorageQuotaBytes *int64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -135,6 +137,8 @@ type CreateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -402,6 +408,14 @@ type adminServiceImpl struct {
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
} else {
|
||||
for i := range users {
|
||||
if rates, ok := ratesByUser[users[i].ID]; ok {
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return
|
||||
}
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
user := &User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SoraVideoPricePerRequestHD != nil {
|
||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
}
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := input.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Notes: normalizeAccountNotes(input.Notes),
|
||||
@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := account.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
for _, groupID := range *input.GroupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if len(input.AccountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
||||
|
||||
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
|
||||
platformByID := map[int64]string{}
|
||||
groupAccountsByID := map[int64][]Account{}
|
||||
groupNameByID := map[int64]string{}
|
||||
if needMixedChannelCheck {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||
if err != nil {
|
||||
@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupAccountsByID = loadedAccounts
|
||||
groupNameByID = loadedNames
|
||||
}
|
||||
|
||||
if input.RateMultiplier != nil {
|
||||
@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
// Handle group bindings per account (requires individual operations).
|
||||
for _, accountID := range input.AccountIDs {
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
platform := ""
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
platform := platformByID[accountID]
|
||||
platform = platformByID[accountID]
|
||||
if platform == "" {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
platform = account.Platform
|
||||
}
|
||||
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
|
||||
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
if !input.SkipMixedChannelCheck && platform != "" {
|
||||
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
|
||||
}
|
||||
}
|
||||
|
||||
entry.Success = true
|
||||
@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
|
||||
accountsByGroup := make(map[int64][]Account)
|
||||
groupNameByID := make(map[int64]string)
|
||||
if len(groupIDs) == 0 {
|
||||
return accountsByGroup, groupNameByID, nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[groupID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[groupID] = struct{}{}
|
||||
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
|
||||
}
|
||||
accountsByGroup[groupID] = accounts
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if group != nil {
|
||||
groupNameByID[groupID] = group.Name
|
||||
}
|
||||
}
|
||||
|
||||
return accountsByGroup, groupNameByID, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.groupRepo == nil {
|
||||
return errors.New("group repository not configured")
|
||||
}
|
||||
|
||||
if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
|
||||
existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check groups exists: %w", err)
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 || !existsByID[groupID] {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
|
||||
currentPlatform := getAccountPlatform(currentAccountPlatform)
|
||||
if currentPlatform == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
accounts := accountsByGroup[groupID]
|
||||
for _, account := range accounts {
|
||||
if currentAccountID > 0 && account.ID == currentAccountID {
|
||||
continue
|
||||
}
|
||||
|
||||
otherPlatform := getAccountPlatform(account.Platform)
|
||||
if otherPlatform == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPlatform != otherPlatform {
|
||||
groupName := fmt.Sprintf("Group %d", groupID)
|
||||
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
|
||||
groupName = name
|
||||
}
|
||||
|
||||
return &MixedChannelError{
|
||||
GroupID: groupID,
|
||||
GroupName: groupName,
|
||||
CurrentPlatform: currentPlatform,
|
||||
OtherPlatform: otherPlatform,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
|
||||
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
|
||||
return
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
continue
|
||||
}
|
||||
accounts := accountsByGroup[groupID]
|
||||
found := false
|
||||
for i := range accounts {
|
||||
if accounts[i].ID != accountID {
|
||||
continue
|
||||
}
|
||||
accounts[i].Platform = platform
|
||||
found = true
|
||||
break
|
||||
}
|
||||
if !found {
|
||||
accounts = append(accounts, Account{
|
||||
ID: accountID,
|
||||
Platform: platform,
|
||||
})
|
||||
}
|
||||
accountsByGroup[groupID] = accounts
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||
|
||||
Reference in New Issue
Block a user