feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -83,13 +83,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
@@ -103,7 +104,8 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
@@ -135,6 +137,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -402,6 +408,14 @@ type adminServiceImpl struct {
authCacheInvalidator APIKeyAuthCacheInvalidator
}
type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
users[i].GroupRates = rates
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
s.loadUserGroupRatesOneByOne(ctx, users)
} else {
for i := range users {
if rates, ok := ratesByUser[users[i].ID]; ok {
users[i].GroupRates = rates
}
}
}
} else {
s.loadUserGroupRatesOneByOne(ctx, users)
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
if s.userGroupRateRepo == nil {
return
}
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
// 检查混合渠道风险(除非用户已确认)
@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 {
return result, nil
}
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
groupAccountsByID := map[int64][]Account{}
groupNameByID := map[int64]string{}
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
}
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
if err != nil {
return nil, err
}
groupAccountsByID = loadedAccounts
groupNameByID = loadedNames
}
if input.RateMultiplier != nil {
@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
platform := ""
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
platform = platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
if !input.SkipMixedChannelCheck && platform != "" {
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
}
}
entry.Success = true
@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
accountsByGroup := make(map[int64][]Account)
groupNameByID := make(map[int64]string)
if len(groupIDs) == 0 {
return accountsByGroup, groupNameByID, nil
}
seen := make(map[int64]struct{}, len(groupIDs))
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
if _, ok := seen[groupID]; ok {
continue
}
seen[groupID] = struct{}{}
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
accountsByGroup[groupID] = accounts
group, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
continue
}
if group != nil {
groupNameByID[groupID] = group.Name
}
}
return accountsByGroup, groupNameByID, nil
}
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return errors.New("group repository not configured")
}
if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 || !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
return nil
}
for _, groupID := range groupIDs {
accounts := accountsByGroup[groupID]
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue
}
if currentPlatform != otherPlatform {
groupName := fmt.Sprintf("Group %d", groupID)
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
groupName = name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
return
}
for _, groupID := range groupIDs {
if groupID <= 0 {
continue
}
accounts := accountsByGroup[groupID]
found := false
for i := range accounts {
if accounts[i].ID != accountID {
continue
}
accounts[i].Platform = platform
found = true
break
}
if !found {
accounts = append(accounts, Account{
ID: accountID,
Platform: platform,
})
}
accountsByGroup[groupID] = accounts
}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)