package service import ( "context" "errors" "fmt" "log" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) // AdminService interface defines admin management operations type AdminService interface { // User management ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) GetUser(ctx context.Context, id int64) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) DeleteUser(ctx context.Context, id int64) error UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) // Account management ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) DeleteAccount(ctx context.Context, id int64) error RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetProxy(ctx context.Context, id int64) (*Proxy, error) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) DeleteProxy(ctx context.Context, id int64) error BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) // Redeem code management ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) DeleteRedeemCode(ctx context.Context, id int64) error BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) } // CreateUserInput represents input for creating a new user via admin operations. type CreateUserInput struct { Email string Password string Username string Notes string Balance float64 Concurrency int AllowedGroups []int64 } type UpdateUserInput struct { Email string Password string Username *string Notes *string Balance *float64 // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" } type CreateGroupInput struct { Name string Description string Platform string RateMultiplier float64 IsExclusive bool SubscriptionType string // standard/subscription DailyLimitUSD *float64 // 日限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) ImagePrice1K *float64 ImagePrice2K *float64 ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID } type UpdateGroupInput struct { Name string Description string Platform string RateMultiplier *float64 // 使用指针以支持设置为0 IsExclusive *bool Status string SubscriptionType string // standard/subscription DailyLimitUSD *float64 // 日限额 (USD) WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) ImagePrice1K *float64 ImagePrice2K *float64 ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID } type CreateAccountInput struct { Name string Notes *string Platform string Type string Credentials map[string]any Extra map[string]any ProxyID *int64 Concurrency int Priority int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) GroupIDs []int64 ExpiresAt *int64 AutoPauseOnExpired *bool // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. SkipMixedChannelCheck bool } type UpdateAccountInput struct { Name string Notes *string Type string // Account type: oauth, setup-token, apikey Credentials map[string]any Extra map[string]any ProxyID *int64 Concurrency *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0" RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) Status string GroupIDs *[]int64 ExpiresAt *int64 AutoPauseOnExpired *bool SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) } // BulkUpdateAccountsInput describes the payload for bulk updating accounts. type BulkUpdateAccountsInput struct { AccountIDs []int64 Name string ProxyID *int64 Concurrency *int Priority *int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) Status string Schedulable *bool GroupIDs *[]int64 Credentials map[string]any Extra map[string]any // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. SkipMixedChannelCheck bool } // BulkUpdateAccountResult captures the result for a single account update. type BulkUpdateAccountResult struct { AccountID int64 `json:"account_id"` Success bool `json:"success"` Error string `json:"error,omitempty"` } // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` Failed int `json:"failed"` SuccessIDs []int64 `json:"success_ids"` FailedIDs []int64 `json:"failed_ids"` Results []BulkUpdateAccountResult `json:"results"` } type CreateProxyInput struct { Name string Protocol string Host string Port int Username string Password string } type UpdateProxyInput struct { Name string Protocol string Host string Port int Username string Password string Status string } type GenerateRedeemCodesInput struct { Count int Type string Value float64 GroupID *int64 // 订阅类型专用:关联的分组ID ValidityDays int // 订阅类型专用:有效天数 } type ProxyBatchDeleteResult struct { DeletedIDs []int64 `json:"deleted_ids"` Skipped []ProxyBatchDeleteSkipped `json:"skipped"` } type ProxyBatchDeleteSkipped struct { ID int64 `json:"id"` Reason string `json:"reason"` } // ProxyTestResult represents the result of testing a proxy type ProxyTestResult struct { Success bool `json:"success"` Message string `json:"message"` LatencyMs int64 `json:"latency_ms,omitempty"` IPAddress string `json:"ip_address,omitempty"` City string `json:"city,omitempty"` Region string `json:"region,omitempty"` Country string `json:"country,omitempty"` CountryCode string `json:"country_code,omitempty"` } // ProxyExitInfo represents proxy exit information from ip-api.com type ProxyExitInfo struct { IP string City string Region string Country string CountryCode string } // ProxyExitInfoProber tests proxy connectivity and retrieves exit information type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache authCacheInvalidator APIKeyAuthCacheInvalidator } // NewAdminService creates a new AdminService func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, authCacheInvalidator APIKeyAuthCacheInvalidator, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, authCacheInvalidator: authCacheInvalidator, } } // User management implementations func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) if err != nil { return nil, 0, err } return users, result.Total, nil } func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { return s.userRepo.GetByID(ctx, id) } func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ Email: input.Email, Username: input.Username, Notes: input.Notes, Role: RoleUser, // Always create as regular user, never admin Balance: input.Balance, Concurrency: input.Concurrency, Status: StatusActive, AllowedGroups: input.AllowedGroups, } if err := user.SetPassword(input.Password); err != nil { return nil, err } if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } return user, nil } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err } // Protect admin users: cannot disable admin accounts if user.Role == "admin" && input.Status == "disabled" { return nil, errors.New("cannot disable admin user") } oldConcurrency := user.Concurrency oldStatus := user.Status oldRole := user.Role if input.Email != "" { user.Email = input.Email } if input.Password != "" { if err := user.SetPassword(input.Password); err != nil { return nil, err } } if input.Username != nil { user.Username = *input.Username } if input.Notes != nil { user.Notes = *input.Notes } if input.Status != "" { user.Status = input.Status } if input.Concurrency != nil { user.Concurrency = *input.Concurrency } if input.AllowedGroups != nil { user.AllowedGroups = *input.AllowedGroups } if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } if s.authCacheInvalidator != nil { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) } } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { code, err := GenerateRedeemCode() if err != nil { log.Printf("failed to generate adjustment redeem code: %v", err) return user, nil } adjustmentRecord := &RedeemCode{ Code: code, Type: AdjustmentTypeAdminConcurrency, Value: float64(concurrencyDiff), Status: StatusUsed, UsedBy: &user.ID, } now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { log.Printf("failed to create concurrency adjustment redeem code: %v", err) } } return user, nil } func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { // Protect admin users: cannot delete admin accounts user, err := s.userRepo.GetByID(ctx, id) if err != nil { return err } if user.Role == "admin" { return errors.New("cannot delete admin user") } if err := s.userRepo.Delete(ctx, id); err != nil { log.Printf("delete user failed: user_id=%d err=%v", id, err) return err } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) } return nil } func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return nil, err } oldBalance := user.Balance switch operation { case "set": user.Balance = balance case "add": user.Balance += balance case "subtract": user.Balance -= balance } if user.Balance < 0 { return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) } if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } balanceDiff := user.Balance - oldBalance if s.authCacheInvalidator != nil && balanceDiff != 0 { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } if s.billingCacheService != nil { go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) } }() } if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { log.Printf("failed to generate adjustment redeem code: %v", err) return user, nil } adjustmentRecord := &RedeemCode{ Code: code, Type: AdjustmentTypeAdminBalance, Value: balanceDiff, Status: StatusUsed, UsedBy: &user.ID, Notes: notes, } now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { log.Printf("failed to create balance adjustment redeem code: %v", err) } } return user, nil } func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) if err != nil { return nil, 0, err } return keys, result.Total, nil } func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { // Return mock data for now return map[string]any{ "period": period, "total_requests": 0, "total_cost": 0.0, "total_tokens": 0, "avg_duration_ms": 0, }, nil } // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) if err != nil { return nil, 0, err } return groups, result.Total, nil } func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { return s.groupRepo.ListActive(ctx) } func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { return s.groupRepo.ListActiveByPlatform(ctx, platform) } func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { return s.groupRepo.GetByID(ctx, id) } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { platform := input.Platform if platform == "" { platform = PlatformAnthropic } subscriptionType := input.SubscriptionType if subscriptionType == "" { subscriptionType = SubscriptionTypeStandard } // 限额字段:0 和 nil 都表示"无限制" dailyLimit := normalizeLimit(input.DailyLimitUSD) weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) // 图片价格:负数表示清除(使用默认价格),0 保留(表示免费) imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) // 校验降级分组 if input.FallbackGroupID != nil { if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { return nil, err } } group := &Group{ Name: input.Name, Description: input.Description, Platform: platform, RateMultiplier: input.RateMultiplier, IsExclusive: input.IsExclusive, Status: StatusActive, SubscriptionType: subscriptionType, DailyLimitUSD: dailyLimit, WeeklyLimitUSD: weeklyLimit, MonthlyLimitUSD: monthlyLimit, ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } return group, nil } // normalizeLimit 将 0 或负数转换为 nil(表示无限制) func normalizeLimit(limit *float64) *float64 { if limit == nil || *limit <= 0 { return nil } return limit } // normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费) func normalizePrice(price *float64) *float64 { if price == nil || *price < 0 { return nil } return price } // validateFallbackGroup 校验降级分组的有效性 // currentGroupID: 当前分组 ID(新建时为 0) // fallbackGroupID: 降级分组 ID func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error { // 不能将自己设置为降级分组 if currentGroupID > 0 && currentGroupID == fallbackGroupID { return fmt.Errorf("cannot set self as fallback group") } visited := map[int64]struct{}{} nextID := fallbackGroupID for { if _, seen := visited[nextID]; seen { return fmt.Errorf("fallback group cycle detected") } visited[nextID] = struct{}{} if currentGroupID > 0 && nextID == currentGroupID { return fmt.Errorf("fallback group cycle detected") } // 检查降级分组是否存在 fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID) if err != nil { return fmt.Errorf("fallback group not found: %w", err) } // 降级分组不能启用 claude_code_only,否则会造成死循环 if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly { return fmt.Errorf("fallback group cannot have claude_code_only enabled") } if fallbackGroup.FallbackGroupID == nil { return nil } nextID = *fallbackGroup.FallbackGroupID } } func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { return nil, err } if input.Name != "" { group.Name = input.Name } if input.Description != "" { group.Description = input.Description } if input.Platform != "" { group.Platform = input.Platform } if input.RateMultiplier != nil { group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { group.IsExclusive = *input.IsExclusive } if input.Status != "" { group.Status = input.Status } // 订阅相关字段 if input.SubscriptionType != "" { group.SubscriptionType = input.SubscriptionType } // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 if input.DailyLimitUSD != nil { group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) } if input.WeeklyLimitUSD != nil { group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) } if input.MonthlyLimitUSD != nil { group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) } // 图片生成计费配置:负数表示清除(使用默认价格) if input.ImagePrice1K != nil { group.ImagePrice1K = normalizePrice(input.ImagePrice1K) } if input.ImagePrice2K != nil { group.ImagePrice2K = normalizePrice(input.ImagePrice2K) } if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { group.ClaudeCodeOnly = *input.ClaudeCodeOnly } if input.FallbackGroupID != nil { // 校验降级分组 if *input.FallbackGroupID > 0 { if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { return nil, err } group.FallbackGroupID = input.FallbackGroupID } else { // 传入 0 或负数表示清除降级分组 group.FallbackGroupID = nil } } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) } return group, nil } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { var groupKeys []string if s.authCacheInvalidator != nil { keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) if err == nil { groupKeys = keys } } affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err } // 事务成功后,异步失效受影响用户的订阅缓存 if len(affectedUserIDs) > 0 && s.billingCacheService != nil { groupID := id go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() for _, userID := range affectedUserIDs { if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) } } }() } if s.authCacheInvalidator != nil { for _, key := range groupKeys { s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) } } return nil } func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) if err != nil { return nil, 0, err } return keys, result.Total, nil } // Account management implementations func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) if err != nil { return nil, 0, err } return accounts, result.Total, nil } func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { return s.accountRepo.GetByID(ctx, id) } func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { if len(ids) == 0 { return []*Account{}, nil } accounts, err := s.accountRepo.GetByIDs(ctx, ids) if err != nil { return nil, fmt.Errorf("failed to get accounts by IDs: %w", err) } return accounts, nil } func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { // 绑定分组 groupIDs := input.GroupIDs // 如果没有指定分组,自动绑定对应平台的默认分组 if len(groupIDs) == 0 { defaultGroupName := input.Platform + "-default" groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) if err == nil { for _, g := range groups { if g.Name == defaultGroupName { groupIDs = []int64{g.ID} break } } } } // 检查混合渠道风险(除非用户已确认) if len(groupIDs) > 0 && !input.SkipMixedChannelCheck { if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil { return nil, err } } account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), Platform: input.Platform, Type: input.Type, Credentials: input.Credentials, Extra: input.Extra, ProxyID: input.ProxyID, Concurrency: input.Concurrency, Priority: input.Priority, Status: StatusActive, Schedulable: true, } if input.ExpiresAt != nil && *input.ExpiresAt > 0 { expiresAt := time.Unix(*input.ExpiresAt, 0) account.ExpiresAt = &expiresAt } if input.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *input.AutoPauseOnExpired } else { account.AutoPauseOnExpired = true } if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") } account.RateMultiplier = input.RateMultiplier } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { return nil, err } } return account, nil } func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err } if input.Name != "" { account.Name = input.Name } if input.Type != "" { account.Type = input.Type } if input.Notes != nil { account.Notes = normalizeAccountNotes(input.Notes) } if len(input.Credentials) > 0 { account.Credentials = input.Credentials } if len(input.Extra) > 0 { account.Extra = input.Extra } if input.ProxyID != nil { // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) if *input.ProxyID == 0 { account.ProxyID = nil } else { account.ProxyID = input.ProxyID } account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID } // 只在指针非 nil 时更新 Concurrency(支持设置为 0) if input.Concurrency != nil { account.Concurrency = *input.Concurrency } // 只在指针非 nil 时更新 Priority(支持设置为 0) if input.Priority != nil { account.Priority = *input.Priority } if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") } account.RateMultiplier = input.RateMultiplier } if input.Status != "" { account.Status = input.Status } if input.ExpiresAt != nil { if *input.ExpiresAt <= 0 { account.ExpiresAt = nil } else { expiresAt := time.Unix(*input.ExpiresAt, 0) account.ExpiresAt = &expiresAt } } if input.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *input.AutoPauseOnExpired } // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { for _, groupID := range *input.GroupIDs { if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { return nil, fmt.Errorf("get group: %w", err) } } // 检查混合渠道风险(除非用户已确认) if !input.SkipMixedChannelCheck { if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil { return nil, err } } } if err := s.accountRepo.Update(ctx, account); err != nil { return nil, err } // 绑定分组 if input.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { return nil, err } } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) return s.accountRepo.GetByID(ctx, id) } // BulkUpdateAccounts updates multiple accounts in one request. // It merges credentials/extra keys instead of overwriting the whole object. func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { result := &BulkUpdateAccountsResult{ SuccessIDs: make([]int64, 0, len(input.AccountIDs)), FailedIDs: make([]int64, 0, len(input.AccountIDs)), Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), } if len(input.AccountIDs) == 0 { return result, nil } // Preload account platforms for mixed channel risk checks if group bindings are requested. platformByID := map[int64]string{} if input.GroupIDs != nil && !input.SkipMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { return nil, err } for _, account := range accounts { if account != nil { platformByID[account.ID] = account.Platform } } } if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") } } // Prepare bulk updates for columns and JSONB fields. repoUpdates := AccountBulkUpdate{ Credentials: input.Credentials, Extra: input.Extra, } if input.Name != "" { repoUpdates.Name = &input.Name } if input.ProxyID != nil { repoUpdates.ProxyID = input.ProxyID } if input.Concurrency != nil { repoUpdates.Concurrency = input.Concurrency } if input.Priority != nil { repoUpdates.Priority = input.Priority } if input.RateMultiplier != nil { repoUpdates.RateMultiplier = input.RateMultiplier } if input.Status != "" { repoUpdates.Status = &input.Status } if input.Schedulable != nil { repoUpdates.Schedulable = input.Schedulable } // Run bulk update for column/jsonb fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { return nil, err } // Handle group bindings per account (requires individual operations). for _, accountID := range input.AccountIDs { entry := BulkUpdateAccountResult{AccountID: accountID} if input.GroupIDs != nil { // 检查混合渠道风险(除非用户已确认) if !input.SkipMixedChannelCheck { platform := platformByID[accountID] if platform == "" { account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } platform = account.Platform } if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } } if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } } entry.Success = true result.Success++ result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) } return result, nil } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { return s.accountRepo.Delete(ctx, id) } func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err } // TODO: Implement refresh logic return account, nil } func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err } account.Status = StatusActive account.ErrorMessage = "" if err := s.accountRepo.Update(ctx, account); err != nil { return nil, err } return account, nil } func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } return s.accountRepo.GetByID(ctx, id) } // Proxy management implementations func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) if err != nil { return nil, 0, err } return proxies, result.Total, nil } func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) if err != nil { return nil, 0, err } s.attachProxyLatency(ctx, proxies) return proxies, result.Total, nil } func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { return s.proxyRepo.ListActive(ctx) } func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx) if err != nil { return nil, err } s.attachProxyLatency(ctx, proxies) return proxies, nil } func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { return s.proxyRepo.GetByID(ctx, id) } func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { proxy := &Proxy{ Name: input.Name, Protocol: input.Protocol, Host: input.Host, Port: input.Port, Username: input.Username, Password: input.Password, Status: StatusActive, } if err := s.proxyRepo.Create(ctx, proxy); err != nil { return nil, err } // Probe latency asynchronously so creation isn't blocked by network timeout. go s.probeProxyLatency(context.Background(), proxy) return proxy, nil } func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { return nil, err } if input.Name != "" { proxy.Name = input.Name } if input.Protocol != "" { proxy.Protocol = input.Protocol } if input.Host != "" { proxy.Host = input.Host } if input.Port != 0 { proxy.Port = input.Port } if input.Username != "" { proxy.Username = input.Username } if input.Password != "" { proxy.Password = input.Password } if input.Status != "" { proxy.Status = input.Status } if err := s.proxyRepo.Update(ctx, proxy); err != nil { return nil, err } return proxy, nil } func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) if err != nil { return err } if count > 0 { return ErrProxyInUse } return s.proxyRepo.Delete(ctx, id) } func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) { result := &ProxyBatchDeleteResult{} if len(ids) == 0 { return result, nil } for _, id := range ids { count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) if err != nil { result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ ID: id, Reason: err.Error(), }) continue } if count > 0 { result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ ID: id, Reason: ErrProxyInUse.Error(), }) continue } if err := s.proxyRepo.Delete(ctx, id); err != nil { result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ ID: id, Reason: err.Error(), }) continue } result.DeletedIDs = append(result.DeletedIDs, id) } return result, nil } func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID) } func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) } // Redeem code management implementations func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) if err != nil { return nil, 0, err } return codes, result.Total, nil } func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { return s.redeemCodeRepo.GetByID(ctx, id) } func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { // 如果是订阅类型,验证必须有 GroupID if input.Type == RedeemTypeSubscription { if input.GroupID == nil { return nil, errors.New("group_id is required for subscription type") } // 验证分组存在且为订阅类型 group, err := s.groupRepo.GetByID(ctx, *input.GroupID) if err != nil { return nil, fmt.Errorf("group not found: %w", err) } if !group.IsSubscriptionType() { return nil, errors.New("group must be subscription type") } } codes := make([]RedeemCode, 0, input.Count) for i := 0; i < input.Count; i++ { codeValue, err := GenerateRedeemCode() if err != nil { return nil, err } code := RedeemCode{ Code: codeValue, Type: input.Type, Value: input.Value, Status: StatusUnused, } // 订阅类型专用字段 if input.Type == RedeemTypeSubscription { code.GroupID = input.GroupID code.ValidityDays = input.ValidityDays if code.ValidityDays <= 0 { code.ValidityDays = 30 // 默认30天 } } if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { return nil, err } codes = append(codes, code) } return codes, nil } func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { return s.redeemCodeRepo.Delete(ctx, id) } func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { var deleted int64 for _, id := range ids { if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { deleted++ } } return deleted, nil } func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemCodeRepo.GetByID(ctx, id) if err != nil { return nil, err } code.Status = StatusExpired if err := s.redeemCodeRepo.Update(ctx, code); err != nil { return nil, err } return code, nil } func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { return nil, err } proxyURL := proxy.URL() exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) if err != nil { s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ Success: false, Message: err.Error(), UpdatedAt: time.Now(), }) return &ProxyTestResult{ Success: false, Message: err.Error(), }, nil } latency := latencyMs s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ Success: true, LatencyMs: &latency, Message: "Proxy is accessible", IPAddress: exitInfo.IP, Country: exitInfo.Country, CountryCode: exitInfo.CountryCode, Region: exitInfo.Region, City: exitInfo.City, UpdatedAt: time.Now(), }) return &ProxyTestResult{ Success: true, Message: "Proxy is accessible", LatencyMs: latencyMs, IPAddress: exitInfo.IP, City: exitInfo.City, Region: exitInfo.Region, Country: exitInfo.Country, CountryCode: exitInfo.CountryCode, }, nil } func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { if s.proxyProber == nil || proxy == nil { return } exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL()) if err != nil { s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ Success: false, Message: err.Error(), UpdatedAt: time.Now(), }) return } latency := latencyMs s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ Success: true, LatencyMs: &latency, Message: "Proxy is accessible", IPAddress: exitInfo.IP, Country: exitInfo.Country, CountryCode: exitInfo.CountryCode, Region: exitInfo.Region, City: exitInfo.City, UpdatedAt: time.Now(), }) } // checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic) // 如果存在混合,返回错误提示用户确认 func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { // 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段) currentPlatform := getAccountPlatform(currentAccountPlatform) if currentPlatform == "" { // 不是 Antigravity 或 Anthropic,无需检查 return nil } // 检查每个分组中的其他账号 for _, groupID := range groupIDs { accounts, err := s.accountRepo.ListByGroup(ctx, groupID) if err != nil { return fmt.Errorf("get accounts in group %d: %w", groupID, err) } // 检查是否存在不同渠道的账号 for _, account := range accounts { if currentAccountID > 0 && account.ID == currentAccountID { continue // 跳过当前账号 } otherPlatform := getAccountPlatform(account.Platform) if otherPlatform == "" { continue // 不是 Antigravity 或 Anthropic,跳过 } // 检测混合渠道 if currentPlatform != otherPlatform { group, _ := s.groupRepo.GetByID(ctx, groupID) groupName := fmt.Sprintf("Group %d", groupID) if group != nil { groupName = group.Name } return &MixedChannelError{ GroupID: groupID, GroupName: groupName, CurrentPlatform: currentPlatform, OtherPlatform: otherPlatform, } } } } return nil } func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { if s.proxyLatencyCache == nil || len(proxies) == 0 { return } ids := make([]int64, 0, len(proxies)) for i := range proxies { ids = append(ids, proxies[i].ID) } latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) if err != nil { log.Printf("Warning: load proxy latency cache failed: %v", err) return } for i := range proxies { info := latencies[proxies[i].ID] if info == nil { continue } if info.Success { proxies[i].LatencyStatus = "success" proxies[i].LatencyMs = info.LatencyMs } else { proxies[i].LatencyStatus = "failed" } proxies[i].LatencyMessage = info.Message proxies[i].IPAddress = info.IPAddress proxies[i].Country = info.Country proxies[i].CountryCode = info.CountryCode proxies[i].Region = info.Region proxies[i].City = info.City } } func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) { if s.proxyLatencyCache == nil || info == nil { return } if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil { log.Printf("Warning: store proxy latency cache failed: %v", err) } } // getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识 func getAccountPlatform(accountPlatform string) string { switch strings.ToLower(strings.TrimSpace(accountPlatform)) { case PlatformAntigravity: return "Antigravity" case PlatformAnthropic, "claude": return "Anthropic" default: return "" } } // MixedChannelError 混合渠道错误 type MixedChannelError struct { GroupID int64 GroupName string CurrentPlatform string OtherPlatform string } func (e *MixedChannelError) Error() string { return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.", e.GroupName, e.CurrentPlatform, e.OtherPlatform) }