Files
sub2api/backend/internal/service/admin_service.go

1187 lines
36 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// AdminService defines the administrative operations for managing users,
// groups, accounts, proxies, and redeem codes.
//
// The paginated List*/Get*APIKeys methods return the page of items together
// with the total count reported by the underlying repository.
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
// UpdateUserBalance changes a user's balance; the implementation in this file
// recognizes the operations "set", "add" and "subtract".
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// CreateUserInput represents input for creating a new user via admin operations.
// CreateUser copies every field except Password onto the new user; Password is
// handed to User.SetPassword instead of being stored directly.
type CreateUserInput struct {
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
}
// UpdateUserInput represents a partial user update via admin operations.
// In UpdateUser, empty strings and nil pointers leave the stored value unchanged.
type UpdateUserInput struct {
Email string
Password string
Username *string
Notes *string
Balance *float64 // pointer distinguishes "not provided" from "set to 0"; NOTE(review): UpdateUser in this file never reads this field
Concurrency *int // pointer distinguishes "not provided" from "set to 0"
Status string
AllowedGroups *[]int64 // pointer distinguishes "not provided" from "set to an empty slice"
}
// CreateGroupInput represents input for creating a group via admin operations.
// CreateGroup defaults an empty Platform to PlatformAnthropic and an empty
// SubscriptionType to SubscriptionTypeStandard.
type CreateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // daily limit (USD)
WeeklyLimitUSD *float64 // weekly limit (USD)
MonthlyLimitUSD *float64 // monthly limit (USD)
// Image generation billing configuration (used only by the antigravity platform).
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
}
// UpdateGroupInput represents a partial group update via admin operations.
// In UpdateGroup, empty strings and nil pointers leave the stored value unchanged.
type UpdateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier *float64 // pointer so that the value can be set to 0
IsExclusive *bool
Status string
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // daily limit (USD)
WeeklyLimitUSD *float64 // weekly limit (USD)
MonthlyLimitUSD *float64 // monthly limit (USD)
// Image generation billing configuration (used only by the antigravity platform).
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
}
// CreateAccountInput represents input for creating an account via admin operations.
// When GroupIDs is empty, CreateAccount tries to bind the active group named
// "<Platform>-default" instead.
type CreateAccountInput struct {
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
// UpdateAccountInput represents a partial account update via admin operations.
// In UpdateAccount, empty strings, empty maps and nil pointers leave the stored
// value unchanged; a ProxyID pointing at 0 clears the account's proxy.
type UpdateAccountInput struct {
Name string
Notes *string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // pointer distinguishes "not provided" from "set to 0"
Priority *int // pointer distinguishes "not provided" from "set to 0"
Status string
GroupIDs *[]int64
SkipMixedChannelCheck bool // skip the mixed channel check (the user has confirmed the risk)
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
// In BulkUpdateAccounts, an empty Name/Status and nil pointers are not forwarded
// as column updates; a nil GroupIDs leaves group bindings untouched.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult captures the result for a single account update.
// Error carries the failure message and is omitted from JSON when empty.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
// Success and Failed count the per-account entries collected in Results.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
}
// CreateProxyInput represents input for creating a proxy via admin operations.
// CreateProxy copies every field onto the new proxy unchanged.
type CreateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
}
// UpdateProxyInput represents a partial proxy update via admin operations.
// In UpdateProxy, zero values (empty string, Port 0) leave the stored value
// unchanged, so these fields cannot be cleared through this input.
type UpdateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
}
// GenerateRedeemCodesInput represents input for generating a batch of redeem codes.
type GenerateRedeemCodesInput struct {
Count int
Type string
Value float64
GroupID *int64 // subscription type only: ID of the associated group
ValidityDays int // subscription type only: validity in days (GenerateRedeemCodes uses 30 when <= 0)
}
// ProxyTestResult represents the result of testing a proxy.
// TestProxy fills the latency and location fields only on success; on failure
// Message carries the probe error text.
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ipinfo.io.
// TestProxy copies these fields into ProxyTestResult.
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information.
// TestProxy reports ProbeProxy's int64 result as the latency in milliseconds.
type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// adminServiceImpl implements AdminService on top of the repositories and
// helper services injected through NewAdminService.
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService // may be nil; callers in this file nil-check before invalidating caches
proxyProber ProxyExitInfoProber
}
// NewAdminService wires the given repositories and helper services into an
// AdminService implementation. The arguments are stored as-is.
func NewAdminService(
	userRepo UserRepository,
	groupRepo GroupRepository,
	accountRepo AccountRepository,
	proxyRepo ProxyRepository,
	apiKeyRepo APIKeyRepository,
	redeemCodeRepo RedeemCodeRepository,
	billingCacheService *BillingCacheService,
	proxyProber ProxyExitInfoProber,
) AdminService {
	svc := new(adminServiceImpl)

	// Repositories.
	svc.userRepo = userRepo
	svc.groupRepo = groupRepo
	svc.accountRepo = accountRepo
	svc.proxyRepo = proxyRepo
	svc.apiKeyRepo = apiKeyRepo
	svc.redeemCodeRepo = redeemCodeRepo

	// Auxiliary services.
	svc.billingCacheService = billingCacheService
	svc.proxyProber = proxyProber

	return svc
}
// User management implementations

// ListUsers returns one page of users matching filters together with the total
// number of matches reported by the repository.
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
	users, pageInfo, err := s.userRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		filters,
	)
	if err != nil {
		return nil, 0, err
	}
	return users, pageInfo.Total, nil
}
// GetUser returns the user with the given id by delegating to the user repository.
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
}
// CreateUser builds a new active user from input, sets its password through
// User.SetPassword, and persists it. The role is always RoleUser; this method
// never creates admins.
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
	created := new(User)
	created.Email = input.Email
	created.Username = input.Username
	created.Notes = input.Notes
	created.Role = RoleUser // Always create as regular user, never admin
	created.Balance = input.Balance
	created.Concurrency = input.Concurrency
	created.Status = StatusActive
	created.AllowedGroups = input.AllowedGroups

	err := created.SetPassword(input.Password)
	if err != nil {
		return nil, err
	}

	err = s.userRepo.Create(ctx, created)
	if err != nil {
		return nil, err
	}
	return created, nil
}
// UpdateUser applies the provided fields of input to the user with the given id
// and persists the result. Empty strings and nil pointers leave the stored value
// unchanged; input.Balance is not applied by this method.
//
// Disabling a user whose role is "admin" is rejected. When the concurrency
// changes, a used redeem code of type AdjustmentTypeAdminConcurrency is recorded
// on a best-effort basis: failures there are logged and do not fail the update.
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Protect admin users: cannot disable admin accounts
	if input.Status == "disabled" && user.Role == "admin" {
		return nil, errors.New("cannot disable admin user")
	}

	previousConcurrency := user.Concurrency

	if email := input.Email; email != "" {
		user.Email = email
	}
	if password := input.Password; password != "" {
		if err := user.SetPassword(password); err != nil {
			return nil, err
		}
	}
	if username := input.Username; username != nil {
		user.Username = *username
	}
	if notes := input.Notes; notes != nil {
		user.Notes = *notes
	}
	if status := input.Status; status != "" {
		user.Status = status
	}
	if concurrency := input.Concurrency; concurrency != nil {
		user.Concurrency = *concurrency
	}
	if allowed := input.AllowedGroups; allowed != nil {
		user.AllowedGroups = *allowed
	}

	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, err
	}

	delta := user.Concurrency - previousConcurrency
	if delta == 0 {
		return user, nil
	}

	// Audit trail for the concurrency change; never fails the update itself.
	code, err := GenerateRedeemCode()
	if err != nil {
		log.Printf("failed to generate adjustment redeem code: %v", err)
		return user, nil
	}
	usedAt := time.Now()
	record := &RedeemCode{
		Code:   code,
		Type:   AdjustmentTypeAdminConcurrency,
		Value:  float64(delta),
		Status: StatusUsed,
		UsedBy: &user.ID,
		UsedAt: &usedAt,
	}
	if err := s.redeemCodeRepo.Create(ctx, record); err != nil {
		log.Printf("failed to create concurrency adjustment redeem code: %v", err)
	}
	return user, nil
}
// DeleteUser removes the user with the given id. Users whose role is "admin"
// cannot be deleted. A repository delete failure is logged and returned.
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
	target, err := s.userRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}

	// Protect admin users: cannot delete admin accounts
	if target.Role == "admin" {
		return errors.New("cannot delete admin user")
	}

	err = s.userRepo.Delete(ctx, id)
	if err == nil {
		return nil
	}
	log.Printf("delete user failed: user_id=%d err=%v", id, err)
	return err
}
// UpdateUserBalance changes the balance of the user identified by userID.
//
// operation selects how balance is applied: "set" replaces the balance, "add"
// increases it and "subtract" decreases it. Any other operation is rejected with
// an error instead of being silently treated as a no-op. A resulting negative
// balance is rejected as well.
//
// After a successful update the billing cache entry for the user is invalidated
// asynchronously (when a billing cache service is configured), and a used redeem
// code of type AdjustmentTypeAdminBalance is recorded for non-zero changes. The
// audit record is best-effort: failures are logged and do not fail the call.
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	oldBalance := user.Balance
	switch operation {
	case "set":
		user.Balance = balance
	case "add":
		user.Balance += balance
	case "subtract":
		user.Balance -= balance
	default:
		// Previously an unknown operation fell through and reported success
		// without changing anything.
		return nil, fmt.Errorf("invalid balance operation: %q", operation)
	}
	if user.Balance < 0 {
		return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
	}

	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, err
	}

	if s.billingCacheService != nil {
		// Detached from the request context so the invalidation still runs
		// after the caller returns; bounded by its own timeout.
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
				log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
			}
		}()
	}

	balanceDiff := user.Balance - oldBalance
	if balanceDiff == 0 {
		return user, nil
	}

	// Audit trail for the balance change; never fails the update itself.
	code, err := GenerateRedeemCode()
	if err != nil {
		log.Printf("failed to generate adjustment redeem code: %v", err)
		return user, nil
	}
	now := time.Now()
	adjustmentRecord := &RedeemCode{
		Code:   code,
		Type:   AdjustmentTypeAdminBalance,
		Value:  balanceDiff,
		Status: StatusUsed,
		UsedBy: &user.ID,
		Notes:  notes,
		UsedAt: &now,
	}
	if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
		log.Printf("failed to create balance adjustment redeem code: %v", err)
	}
	return user, nil
}
// GetUserAPIKeys returns one page of API keys owned by userID together with the
// total count reported by the repository.
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
	keys, pageInfo, err := s.apiKeyRepo.ListByUserID(
		ctx,
		userID,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
	)
	if err != nil {
		return nil, 0, err
	}
	return keys, pageInfo.Total, nil
}
// GetUserUsageStats currently returns placeholder data: the requested period
// echoed back with all counters set to zero. ctx and userID are not used yet.
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
	// Return mock data for now
	stats := make(map[string]any, 5)
	stats["period"] = period
	stats["total_requests"] = 0
	stats["total_cost"] = 0.0
	stats["total_tokens"] = 0
	stats["avg_duration_ms"] = 0
	return stats, nil
}
// Group management implementations

// ListGroups returns one page of groups filtered by platform, status and
// exclusivity, together with the total count reported by the repository.
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
	groups, pageInfo, err := s.groupRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		platform,
		status,
		isExclusive,
	)
	if err != nil {
		return nil, 0, err
	}
	return groups, pageInfo.Total, nil
}
// GetAllGroups returns the result of the group repository's ListActive call.
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
return s.groupRepo.ListActive(ctx)
}
// GetAllGroupsByPlatform returns the result of the group repository's
// ListActiveByPlatform call for the given platform.
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
// GetGroup returns the group with the given id by delegating to the group repository.
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
return s.groupRepo.GetByID(ctx, id)
}
// CreateGroup persists a new active group built from input.
//
// An empty platform defaults to PlatformAnthropic and an empty subscription type
// to SubscriptionTypeStandard. Limits are passed through normalizeLimit (nil and
// non-positive values mean "unlimited") and image prices through normalizePrice
// (negative values mean "use the default price"; 0 is kept and means free).
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
	group := &Group{
		Name:             input.Name,
		Description:      input.Description,
		Platform:         input.Platform,
		RateMultiplier:   input.RateMultiplier,
		IsExclusive:      input.IsExclusive,
		Status:           StatusActive,
		SubscriptionType: input.SubscriptionType,
		DailyLimitUSD:    normalizeLimit(input.DailyLimitUSD),
		WeeklyLimitUSD:   normalizeLimit(input.WeeklyLimitUSD),
		MonthlyLimitUSD:  normalizeLimit(input.MonthlyLimitUSD),
		ImagePrice1K:     normalizePrice(input.ImagePrice1K),
		ImagePrice2K:     normalizePrice(input.ImagePrice2K),
		ImagePrice4K:     normalizePrice(input.ImagePrice4K),
	}

	// Fill in defaults for fields the caller left empty.
	if group.Platform == "" {
		group.Platform = PlatformAnthropic
	}
	if group.SubscriptionType == "" {
		group.SubscriptionType = SubscriptionTypeStandard
	}

	err := s.groupRepo.Create(ctx, group)
	if err != nil {
		return nil, err
	}
	return group, nil
}
// normalizeLimit maps a nil, zero or negative limit to nil, which means
// "unlimited". Any other value is returned unchanged (same pointer).
func normalizeLimit(limit *float64) *float64 {
	switch {
	case limit == nil:
		return nil
	case *limit <= 0:
		return nil
	default:
		return limit
	}
}
// normalizePrice maps a nil or negative price to nil, which means "use the
// default price". Zero is kept (it means free) and any other value is returned
// unchanged (same pointer).
func normalizePrice(price *float64) *float64 {
	switch {
	case price == nil:
		return nil
	case *price < 0:
		return nil
	default:
		return price
	}
}
// UpdateGroup applies the provided fields of input to the group with the given
// id and persists the result. Empty strings and nil pointers leave the stored
// value unchanged. Provided limits go through normalizeLimit and provided image
// prices through normalizePrice.
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
	group, err := s.groupRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Basic attributes.
	if name := input.Name; name != "" {
		group.Name = name
	}
	if description := input.Description; description != "" {
		group.Description = description
	}
	if platform := input.Platform; platform != "" {
		group.Platform = platform
	}
	if multiplier := input.RateMultiplier; multiplier != nil {
		group.RateMultiplier = *multiplier
	}
	if exclusive := input.IsExclusive; exclusive != nil {
		group.IsExclusive = *exclusive
	}
	if status := input.Status; status != "" {
		group.Status = status
	}

	// Subscription-related fields.
	if subscriptionType := input.SubscriptionType; subscriptionType != "" {
		group.SubscriptionType = subscriptionType
	}

	// Limits: 0 and nil both mean "unlimited"; a positive value is a concrete limit.
	if daily := input.DailyLimitUSD; daily != nil {
		group.DailyLimitUSD = normalizeLimit(daily)
	}
	if weekly := input.WeeklyLimitUSD; weekly != nil {
		group.WeeklyLimitUSD = normalizeLimit(weekly)
	}
	if monthly := input.MonthlyLimitUSD; monthly != nil {
		group.MonthlyLimitUSD = normalizeLimit(monthly)
	}

	// Image generation billing: a negative value clears the price (use the default).
	if price := input.ImagePrice1K; price != nil {
		group.ImagePrice1K = normalizePrice(price)
	}
	if price := input.ImagePrice2K; price != nil {
		group.ImagePrice2K = normalizePrice(price)
	}
	if price := input.ImagePrice4K; price != nil {
		group.ImagePrice4K = normalizePrice(price)
	}

	err = s.groupRepo.Update(ctx, group)
	if err != nil {
		return nil, err
	}
	return group, nil
}
// DeleteGroup deletes the group through the repository's DeleteCascade and then
// asynchronously invalidates the subscription cache of every affected user.
// Cache invalidation failures are only logged.
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
	affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
	if err != nil {
		return err
	}

	// Nothing to invalidate, or no cache service configured.
	if len(affectedUserIDs) == 0 || s.billingCacheService == nil {
		return nil
	}

	// After the delete succeeded, invalidate the affected users' subscription
	// caches in the background with a context independent of the request.
	groupID := id
	go func() {
		cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		for _, userID := range affectedUserIDs {
			err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
			if err != nil {
				log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
			}
		}
	}()
	return nil
}
// GetGroupAPIKeys returns one page of API keys bound to groupID together with
// the total count reported by the repository.
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
	keys, pageInfo, err := s.apiKeyRepo.ListByGroupID(
		ctx,
		groupID,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
	)
	if err != nil {
		return nil, 0, err
	}
	return keys, pageInfo.Total, nil
}
// Account management implementations

// ListAccounts returns one page of accounts filtered by platform, type, status
// and search term, together with the total count reported by the repository.
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
	accounts, pageInfo, err := s.accountRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		platform,
		accountType,
		status,
		search,
	)
	if err != nil {
		return nil, 0, err
	}
	return accounts, pageInfo.Total, nil
}
// GetAccount returns the account with the given id by delegating to the account repository.
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
return s.accountRepo.GetByID(ctx, id)
}
// GetAccountsByIDs loads the accounts with the given ids. An empty ids slice
// yields an empty, non-nil result without touching the repository. Repository
// errors are wrapped.
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	if len(ids) == 0 {
		return []*Account{}, nil
	}

	found, err := s.accountRepo.GetByIDs(ctx, ids)
	if err == nil {
		return found, nil
	}
	return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
}
// CreateAccount persists a new active, schedulable account and binds it to groups.
//
// When input.GroupIDs is empty, the active group named "<platform>-default" is
// used if it exists; a failure of that lookup is ignored and the account is then
// created without group bindings. Unless input.SkipMixedChannelCheck is set, the
// target groups are checked with checkMixedChannelRisk before anything is written.
//
// NOTE(review): if BindGroups fails, the error is returned but the account
// created just before it is not removed here.
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
	groupIDs := input.GroupIDs

	// No groups given: fall back to the platform's default group, best effort.
	if len(groupIDs) == 0 {
		if candidates, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform); err == nil {
			wanted := input.Platform + "-default"
			for i := range candidates {
				if candidates[i].Name == wanted {
					groupIDs = []int64{candidates[i].ID}
					break
				}
			}
		}
	}
	bindGroups := len(groupIDs) > 0

	// Check mixed channel risk (unless the user has already confirmed it).
	if bindGroups && !input.SkipMixedChannelCheck {
		err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs)
		if err != nil {
			return nil, err
		}
	}

	account := &Account{
		Name:        input.Name,
		Notes:       normalizeAccountNotes(input.Notes),
		Platform:    input.Platform,
		Type:        input.Type,
		Credentials: input.Credentials,
		Extra:       input.Extra,
		ProxyID:     input.ProxyID,
		Concurrency: input.Concurrency,
		Priority:    input.Priority,
		Status:      StatusActive,
		Schedulable: true,
	}
	if err := s.accountRepo.Create(ctx, account); err != nil {
		return nil, err
	}

	if !bindGroups {
		return account, nil
	}
	if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
		return nil, err
	}
	return account, nil
}
// UpdateAccount applies the provided fields of input to the account with the
// given id, persists it, optionally rebinds its groups, and returns the account
// as re-read from the repository.
//
// Empty strings, empty maps and nil pointers leave the stored value unchanged.
// A ProxyID pointing at 0 clears the proxy. When GroupIDs is provided, every
// group must exist and, unless SkipMixedChannelCheck is set, pass
// checkMixedChannelRisk; both checks run before any write.
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
	account, err := s.accountRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if name := input.Name; name != "" {
		account.Name = name
	}
	if accountType := input.Type; accountType != "" {
		account.Type = accountType
	}
	if input.Notes != nil {
		account.Notes = normalizeAccountNotes(input.Notes)
	}
	if credentials := input.Credentials; len(credentials) > 0 {
		account.Credentials = credentials
	}
	if extra := input.Extra; len(extra) > 0 {
		account.Extra = extra
	}
	if proxyID := input.ProxyID; proxyID != nil {
		account.ProxyID = proxyID
		// 0 means "clear the proxy" (the frontend sends 0 rather than null for that).
		if *proxyID == 0 {
			account.ProxyID = nil
		}
		// Drop the loaded association so a GORM Save cannot overwrite ProxyID from Proxy.ID.
		account.Proxy = nil
	}
	// Pointers allow Concurrency and Priority to be set to 0 explicitly.
	if concurrency := input.Concurrency; concurrency != nil {
		account.Concurrency = *concurrency
	}
	if priority := input.Priority; priority != nil {
		account.Priority = *priority
	}
	if status := input.Status; status != "" {
		account.Status = status
	}

	rebindGroups := input.GroupIDs != nil
	if rebindGroups {
		// Validate that every group exists before performing any write.
		for _, groupID := range *input.GroupIDs {
			_, err := s.groupRepo.GetByID(ctx, groupID)
			if err != nil {
				return nil, fmt.Errorf("get group: %w", err)
			}
		}
		// Check mixed channel risk (unless the user has already confirmed it).
		if !input.SkipMixedChannelCheck {
			err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs)
			if err != nil {
				return nil, err
			}
		}
	}

	if err := s.accountRepo.Update(ctx, account); err != nil {
		return nil, err
	}

	if rebindGroups {
		if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
			return nil, err
		}
	}

	// Re-read so the result is complete, including the correct Proxy association.
	return s.accountRepo.GetByID(ctx, id)
}
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
//
// Column and JSON fields are written first with a single repository BulkUpdate;
// a failure there aborts the whole call. Group bindings, when requested, are then
// applied account by account, and each account's outcome is reported in the
// result instead of failing the call. Unless SkipMixedChannelCheck is set, every
// rebinding is preceded by checkMixedChannelRisk.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
	result := &BulkUpdateAccountsResult{
		Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
	}
	if len(input.AccountIDs) == 0 {
		return result, nil
	}

	rebindGroups := input.GroupIDs != nil
	checkRisk := rebindGroups && !input.SkipMixedChannelCheck

	// Preload account platforms for mixed channel risk checks if group bindings are requested.
	platformByID := map[int64]string{}
	if checkRisk {
		preloaded, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
		if err != nil {
			return nil, err
		}
		for _, account := range preloaded {
			if account == nil {
				continue
			}
			platformByID[account.ID] = account.Platform
		}
	}

	// Prepare bulk updates for columns and JSONB fields. Nil pointers stay nil.
	repoUpdates := AccountBulkUpdate{
		ProxyID:     input.ProxyID,
		Concurrency: input.Concurrency,
		Priority:    input.Priority,
		Credentials: input.Credentials,
		Extra:       input.Extra,
	}
	if input.Name != "" {
		repoUpdates.Name = &input.Name
	}
	if input.Status != "" {
		repoUpdates.Status = &input.Status
	}

	// Run bulk update for column/jsonb fields first.
	if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
		return nil, err
	}

	// applyGroups performs the per-account group rebinding, including the
	// optional mixed channel check. It is a no-op when no rebinding was requested.
	applyGroups := func(accountID int64) error {
		if !rebindGroups {
			return nil
		}
		if checkRisk {
			platform := platformByID[accountID]
			if platform == "" {
				// Not found in the preload: look the account up individually.
				account, err := s.accountRepo.GetByID(ctx, accountID)
				if err != nil {
					return err
				}
				platform = account.Platform
			}
			if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
				return err
			}
		}
		return s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs)
	}

	// record appends one per-account entry and updates the counters.
	record := func(accountID int64, err error) {
		entry := BulkUpdateAccountResult{AccountID: accountID}
		if err != nil {
			entry.Error = err.Error()
			result.Failed++
		} else {
			entry.Success = true
			result.Success++
		}
		result.Results = append(result.Results, entry)
	}

	// Handle group bindings per account (requires individual operations).
	for _, accountID := range input.AccountIDs {
		record(accountID, applyGroups(accountID))
	}
	return result, nil
}
// DeleteAccount deletes the account with the given id by delegating to the account repository.
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
// RefreshAccountCredentials is currently a stub: it loads the account and returns
// it unchanged. No credential refresh is performed yet (see the TODO below).
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// TODO: Implement refresh logic
return account, nil
}
// ClearAccountError resets the account's status to StatusActive, clears its
// error message, and persists the change.
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
	acct, err := s.accountRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	acct.Status, acct.ErrorMessage = StatusActive, ""

	err = s.accountRepo.Update(ctx, acct)
	if err != nil {
		return nil, err
	}
	return acct, nil
}
// SetAccountSchedulable stores the schedulable flag for the account and returns
// the account as re-read from the repository.
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
	err := s.accountRepo.SetSchedulable(ctx, id, schedulable)
	if err != nil {
		return nil, err
	}
	return s.accountRepo.GetByID(ctx, id)
}
// Proxy management implementations

// ListProxies returns one page of proxies filtered by protocol, status and
// search term, together with the total count reported by the repository.
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
	proxies, pageInfo, err := s.proxyRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		protocol,
		status,
		search,
	)
	if err != nil {
		return nil, 0, err
	}
	return proxies, pageInfo.Total, nil
}
// GetAllProxies returns the result of the proxy repository's ListActive call.
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}
// GetAllProxiesWithAccountCount returns the result of the proxy repository's
// ListActiveWithAccountCount call.
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
// GetProxy returns the proxy with the given id by delegating to the proxy repository.
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
return s.proxyRepo.GetByID(ctx, id)
}
// CreateProxy persists a new active proxy built from input.
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
	created := new(Proxy)
	created.Name = input.Name
	created.Protocol = input.Protocol
	created.Host = input.Host
	created.Port = input.Port
	created.Username = input.Username
	created.Password = input.Password
	created.Status = StatusActive

	err := s.proxyRepo.Create(ctx, created)
	if err != nil {
		return nil, err
	}
	return created, nil
}
// UpdateProxy applies the provided fields of input to the proxy with the given
// id and persists the result. Zero values (empty string, port 0) leave the
// stored value unchanged.
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
	proxy, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if name := input.Name; name != "" {
		proxy.Name = name
	}
	if protocol := input.Protocol; protocol != "" {
		proxy.Protocol = protocol
	}
	if host := input.Host; host != "" {
		proxy.Host = host
	}
	if port := input.Port; port != 0 {
		proxy.Port = port
	}
	if username := input.Username; username != "" {
		proxy.Username = username
	}
	if password := input.Password; password != "" {
		proxy.Password = password
	}
	if status := input.Status; status != "" {
		proxy.Status = status
	}

	err = s.proxyRepo.Update(ctx, proxy)
	if err != nil {
		return nil, err
	}
	return proxy, nil
}
// DeleteProxy deletes the proxy with the given id by delegating to the proxy repository.
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id)
}
// GetProxyAccounts is currently a placeholder: it ignores its arguments and
// always returns an empty slice, a total of 0 and a nil error.
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []Account{}, 0, nil
}
// CheckProxyExists returns the result of the proxy repository's
// ExistsByHostPortAuth call for the given host, port and credentials.
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
}
// Redeem code management implementations

// ListRedeemCodes returns one page of redeem codes filtered by type, status and
// search term, together with the total count reported by the repository.
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
	codes, pageInfo, err := s.redeemCodeRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		codeType,
		status,
		search,
	)
	if err != nil {
		return nil, 0, err
	}
	return codes, pageInfo.Total, nil
}
// GetRedeemCode returns the redeem code with the given id by delegating to the
// redeem code repository.
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id)
}
// GenerateRedeemCodes creates and persists input.Count unused redeem codes of
// the requested type and value, and returns them.
//
// A negative Count is rejected with an error; a Count of 0 yields an empty
// result. For the subscription type, GroupID is required and must reference an
// existing subscription-type group; each code then carries the group and the
// validity in days (30 when ValidityDays is not positive).
//
// Codes are persisted one by one. If generation or persistence fails midway the
// error is returned and codes created earlier in the same call remain stored.
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
	// make panics at run time when given a negative capacity, so reject it here.
	if input.Count < 0 {
		return nil, errors.New("count must not be negative")
	}

	isSubscription := input.Type == RedeemTypeSubscription

	// Subscription codes must be tied to an existing subscription-type group.
	if isSubscription {
		if input.GroupID == nil {
			return nil, errors.New("group_id is required for subscription type")
		}
		group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
		if err != nil {
			return nil, fmt.Errorf("group not found: %w", err)
		}
		if !group.IsSubscriptionType() {
			return nil, errors.New("group must be subscription type")
		}
	}

	codes := make([]RedeemCode, 0, input.Count)
	for i := 0; i < input.Count; i++ {
		codeValue, err := GenerateRedeemCode()
		if err != nil {
			return nil, err
		}
		code := RedeemCode{
			Code:   codeValue,
			Type:   input.Type,
			Value:  input.Value,
			Status: StatusUnused,
		}
		// Fields used only by subscription codes.
		if isSubscription {
			code.GroupID = input.GroupID
			code.ValidityDays = input.ValidityDays
			if code.ValidityDays <= 0 {
				code.ValidityDays = 30 // default: 30 days
			}
		}
		if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
			return nil, err
		}
		codes = append(codes, code)
	}
	return codes, nil
}
// DeleteRedeemCode deletes the redeem code with the given id by delegating to
// the redeem code repository.
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
return s.redeemCodeRepo.Delete(ctx, id)
}
// BatchDeleteRedeemCodes deletes the given redeem codes one by one and returns
// how many deletions succeeded. The operation is best-effort: a failed deletion
// does not stop the batch and the returned error is always nil, but each failure
// is logged so it is not lost silently.
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
	var deleted int64
	for _, id := range ids {
		if err := s.redeemCodeRepo.Delete(ctx, id); err != nil {
			log.Printf("batch delete redeem code failed: redeem_code_id=%d err=%v", id, err)
			continue
		}
		deleted++
	}
	return deleted, nil
}
// ExpireRedeemCode sets the status of the redeem code with the given id to
// StatusExpired and persists the change.
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
	redeemCode, err := s.redeemCodeRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	redeemCode.Status = StatusExpired

	err = s.redeemCodeRepo.Update(ctx, redeemCode)
	if err != nil {
		return nil, err
	}
	return redeemCode, nil
}
// TestProxy probes the proxy with the given id and reports the outcome.
//
// A failure to load the proxy is returned as an error. Probe failures are not
// errors: they are reported as a ProxyTestResult with Success false and the
// failure text in Message. The same applies when no prober is configured. On
// success the result carries the latency and, when the prober supplied it, the
// exit IP and location.
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
	proxy, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// The prober is an injected dependency; guard against it being absent
	// instead of panicking on a nil interface call.
	if s.proxyProber == nil {
		return &ProxyTestResult{
			Success: false,
			Message: "proxy prober is not configured",
		}, nil
	}

	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
	if err != nil {
		return &ProxyTestResult{
			Success: false,
			Message: err.Error(),
		}, nil
	}

	result := &ProxyTestResult{
		Success:   true,
		Message:   "Proxy is accessible",
		LatencyMs: latencyMs,
	}
	// A prober may succeed without exit details; leave those fields empty then.
	if exitInfo != nil {
		result.IPAddress = exitInfo.IP
		result.City = exitInfo.City
		result.Region = exitInfo.Region
		result.Country = exitInfo.Country
	}
	return result, nil
}
// checkMixedChannelRisk reports whether binding an account of the given platform
// to groupIDs would mix channels (Antigravity + Anthropic) inside one group.
//
// Platforms are classified with getAccountPlatform, based on the platform field
// rather than the account type. If the current platform is neither channel, no
// check is needed and nil is returned. Otherwise every other account in each
// group is inspected; currentAccountID (when > 0) is skipped. On the first
// mismatch a *MixedChannelError is returned so the caller can ask the user to
// confirm. Errors listing a group's accounts are wrapped and returned.
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
	ownChannel := getAccountPlatform(currentAccountPlatform)
	if ownChannel == "" {
		// Neither Antigravity nor Anthropic: nothing to check.
		return nil
	}

	for _, groupID := range groupIDs {
		members, err := s.accountRepo.ListByGroup(ctx, groupID)
		if err != nil {
			return fmt.Errorf("get accounts in group %d: %w", groupID, err)
		}

		for _, member := range members {
			// Skip the account being updated itself.
			if currentAccountID > 0 && member.ID == currentAccountID {
				continue
			}
			memberChannel := getAccountPlatform(member.Platform)
			// Skip accounts outside both channels and accounts on the same channel.
			if memberChannel == "" || memberChannel == ownChannel {
				continue
			}

			// Mixed channels detected. The group lookup only improves the
			// message; on failure a generic label is used instead.
			groupName := fmt.Sprintf("Group %d", groupID)
			if group, _ := s.groupRepo.GetByID(ctx, groupID); group != nil {
				groupName = group.Name
			}
			return &MixedChannelError{
				GroupID:         groupID,
				GroupName:       groupName,
				CurrentPlatform: ownChannel,
				OtherPlatform:   memberChannel,
			}
		}
	}
	return nil
}
// getAccountPlatform maps an account's platform value to the channel label used
// by the mixed channel check. The input is trimmed and lower-cased first.
// It returns "Antigravity" for PlatformAntigravity, "Anthropic" for
// PlatformAnthropic or "claude", and "" for anything else.
func getAccountPlatform(accountPlatform string) string {
	normalized := strings.ToLower(strings.TrimSpace(accountPlatform))
	if normalized == PlatformAntigravity {
		return "Antigravity"
	}
	if normalized == PlatformAnthropic || normalized == "claude" {
		return "Anthropic"
	}
	return ""
}
// MixedChannelError is returned when a group would contain accounts from both
// the Antigravity and the Anthropic channel.
type MixedChannelError struct {
	GroupID         int64
	GroupName       string
	CurrentPlatform string
	OtherPlatform   string
}

// Error implements the error interface. The message starts with the
// "mixed_channel_warning:" marker followed by the group name and both channels.
func (e *MixedChannelError) Error() string {
	const format = "mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages."
	return fmt.Sprintf(format, e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}