Files
yinghuoapi/backend/internal/service/admin_service.go
2026-02-06 11:33:45 +08:00

1773 lines
56 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// AdminService interface defines admin management operations
// for users, groups, upstream accounts, proxies and redeem codes.
//
// Paginated methods take page/pageSize and return the requested page together
// with the total number of matching rows reported by the repository.
type AdminService interface {
	// User management
	ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
	GetUser(ctx context.Context, id int64) (*User, error)
	CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
	UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
	DeleteUser(ctx context.Context, id int64) error
	// UpdateUserBalance applies operation ("set", "add" or "subtract") with the
	// given amount; notes are stored on the resulting adjustment record.
	UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
	GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
	// GetUserUsageStats returns usage statistics for period. The current
	// implementation returns placeholder zero values.
	GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
	// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
	// codeType is optional - pass empty string to return all types.
	// Also returns totalRecharged (sum of all positive balance top-ups).
	GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
	// Group management
	ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
	// GetAllGroups returns all active groups (not paginated).
	GetAllGroups(ctx context.Context) ([]Group, error)
	// GetAllGroupsByPlatform returns all active groups of one platform (not paginated).
	GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
	GetGroup(ctx context.Context, id int64) (*Group, error)
	CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
	UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
	DeleteGroup(ctx context.Context, id int64) error
	GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
	// Account management
	ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
	GetAccount(ctx context.Context, id int64) (*Account, error)
	GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
	CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
	UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
	DeleteAccount(ctx context.Context, id int64) error
	RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
	ClearAccountError(ctx context.Context, id int64) (*Account, error)
	SetAccountError(ctx context.Context, id int64, errorMsg string) error
	SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
	// BulkUpdateAccounts applies one update to many accounts and reports
	// per-account success/failure.
	BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
	// Proxy management
	ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
	ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
	GetAllProxies(ctx context.Context) ([]Proxy, error)
	GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
	GetProxy(ctx context.Context, id int64) (*Proxy, error)
	GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error)
	CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
	UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
	DeleteProxy(ctx context.Context, id int64) error
	BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
	GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
	CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
	TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
	// Redeem code management
	ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
	GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
	GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
	DeleteRedeemCode(ctx context.Context, id int64) error
	BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
	ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// CreateUserInput represents input for creating a new user via admin operations.
// CreateUser always assigns the regular user role and active status; neither
// can be chosen through this input.
type CreateUserInput struct {
	Email string
	Password string // plain-text password, passed to User.SetPassword
	Username string
	Notes string
	Balance float64
	Concurrency int
	AllowedGroups []int64
}
// UpdateUserInput carries a partial user update. Empty strings and nil
// pointers/maps leave the corresponding field unchanged.
//
// NOTE(review): UpdateUser in this file never reads Balance; balance changes
// are applied through UpdateUserBalance instead.
type UpdateUserInput struct {
	Email string
	Password string
	Username *string
	Notes *string
	Balance *float64 // pointer distinguishes "not provided" from "set to 0"
	Concurrency *int // pointer distinguishes "not provided" from "set to 0"
	Status string
	AllowedGroups *[]int64 // pointer distinguishes "not provided" from "set to an empty list"
	// GroupRates holds user-specific group rate multipliers, keyed by group ID
	// (map[groupID]*rate). A nil rate removes the user's override for that group.
	GroupRates map[int64]*float64
}
// CreateGroupInput describes a new group. CreateGroup defaults an empty
// Platform to anthropic and an empty SubscriptionType to standard.
type CreateGroupInput struct {
	Name string
	Description string
	Platform string
	RateMultiplier float64
	IsExclusive bool
	SubscriptionType string // standard/subscription
	DailyLimitUSD *float64 // daily limit (USD); nil or <= 0 means unlimited
	WeeklyLimitUSD *float64 // weekly limit (USD); nil or <= 0 means unlimited
	MonthlyLimitUSD *float64 // monthly limit (USD); nil or <= 0 means unlimited
	// Image generation pricing (antigravity platform only). A negative value
	// clears the price (default pricing applies); 0 is kept and means free.
	ImagePrice1K *float64
	ImagePrice2K *float64
	ImagePrice4K *float64
	ClaudeCodeOnly bool // only allow Claude Code clients
	FallbackGroupID *int64 // fallback (downgrade) group ID
	// Fallback group ID used for invalid requests (documented as anthropic
	// platform only; NOTE(review): validation also accepts antigravity groups).
	// Values <= 0 are treated as "none".
	FallbackGroupIDOnInvalidRequest *int64
	// Model routing configuration (anthropic platform only).
	// NOTE(review): key/ID semantics are defined by the routing layer — confirm.
	ModelRouting map[string][]int64
	ModelRoutingEnabled bool // whether model routing is enabled
	// MCPXMLInject defaults to true when nil; only an explicit false disables it.
	MCPXMLInject *bool
	// Supported model families (antigravity platform only).
	SupportedModelScopes []string
	// Groups whose accounts are bound to the new group after it is created.
	// Source groups must be on the same platform as the new group.
	CopyAccountsFromGroupIDs []int64
}
// UpdateGroupInput carries a partial group update. Empty strings and nil
// pointers/maps leave the corresponding field unchanged.
type UpdateGroupInput struct {
	Name string
	Description string
	Platform string
	RateMultiplier *float64 // pointer so that 0 can be set explicitly
	IsExclusive *bool
	Status string
	SubscriptionType string // standard/subscription
	DailyLimitUSD *float64 // daily limit (USD); 0 or negative means unlimited
	WeeklyLimitUSD *float64 // weekly limit (USD); 0 or negative means unlimited
	MonthlyLimitUSD *float64 // monthly limit (USD); 0 or negative means unlimited
	// Image generation pricing (antigravity platform only). A negative value
	// clears the price (default pricing applies); 0 means free.
	ImagePrice1K *float64
	ImagePrice2K *float64
	ImagePrice4K *float64
	ClaudeCodeOnly *bool // only allow Claude Code clients
	FallbackGroupID *int64 // fallback (downgrade) group ID; > 0 sets, 0 or negative clears
	// Fallback group ID used for invalid requests (documented as anthropic
	// platform only); > 0 sets, 0 or negative clears.
	FallbackGroupIDOnInvalidRequest *int64
	// Model routing configuration (anthropic platform only).
	ModelRouting map[string][]int64
	ModelRoutingEnabled *bool // whether model routing is enabled
	MCPXMLInject *bool
	// Supported model families (antigravity platform only).
	SupportedModelScopes *[]string
	// Copy accounts from these groups. This is a sync operation: the group's
	// current account bindings are cleared first, then the source groups'
	// accounts are bound.
	CopyAccountsFromGroupIDs []int64
}
// CreateAccountInput describes a new upstream account.
type CreateAccountInput struct {
	Name string
	Notes *string
	Platform string
	Type string
	Credentials map[string]any
	Extra map[string]any
	ProxyID *int64
	Concurrency int
	Priority int
	RateMultiplier *float64 // account billing multiplier (must be >= 0; 0 is allowed)
	GroupIDs []int64
	ExpiresAt *int64 // unix seconds; only applied when > 0
	AutoPauseOnExpired *bool // defaults to true when nil
	// SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty.
	SkipDefaultGroupBind bool
	// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
	// This should only be set when the caller has explicitly confirmed the risk.
	SkipMixedChannelCheck bool
}
// UpdateAccountInput carries a partial account update.
type UpdateAccountInput struct {
	Name string
	Notes *string
	Type string // Account type: oauth, setup-token, apikey
	Credentials map[string]any
	Extra map[string]any
	ProxyID *int64
	Concurrency *int // pointer distinguishes "not provided" from "set to 0"
	Priority *int // pointer distinguishes "not provided" from "set to 0"
	RateMultiplier *float64 // account billing multiplier (>= 0; 0 is allowed)
	Status string
	GroupIDs *[]int64
	ExpiresAt *int64
	AutoPauseOnExpired *bool
	SkipMixedChannelCheck bool // skip the mixed channel check (the user has confirmed the risk)
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
// The same field changes are applied to every account in AccountIDs.
type BulkUpdateAccountsInput struct {
	AccountIDs []int64
	Name string
	ProxyID *int64
	Concurrency *int
	Priority *int
	RateMultiplier *float64 // account billing multiplier (>= 0; 0 is allowed)
	Status string
	Schedulable *bool
	GroupIDs *[]int64
	Credentials map[string]any
	Extra map[string]any
	// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
	// This should only be set when the caller has explicitly confirmed the risk.
	SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult captures the result for a single account update.
// Error is omitted from JSON when empty.
type BulkUpdateAccountResult struct {
	AccountID int64 `json:"account_id"`
	Success bool `json:"success"`
	Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
// Success/Failed are counts; SuccessIDs/FailedIDs list the account IDs in
// each bucket, and Results holds the per-account details.
type BulkUpdateAccountsResult struct {
	Success int `json:"success"`
	Failed int `json:"failed"`
	SuccessIDs []int64 `json:"success_ids"`
	FailedIDs []int64 `json:"failed_ids"`
	Results []BulkUpdateAccountResult `json:"results"`
}
// CreateProxyInput describes a new outbound proxy.
type CreateProxyInput struct {
	Name string
	Protocol string
	Host string
	Port int
	Username string // optional proxy auth user
	Password string // optional proxy auth password
}
// UpdateProxyInput carries an update for an existing proxy.
// NOTE(review): presumably zero values mean "unchanged" as in the other
// Update*Input types — confirm against UpdateProxy.
type UpdateProxyInput struct {
	Name string
	Protocol string
	Host string
	Port int
	Username string
	Password string
	Status string
}
// GenerateRedeemCodesInput describes a batch of redeem codes to generate.
type GenerateRedeemCodesInput struct {
	Count int
	Type string
	Value float64
	GroupID *int64 // subscription type only: the associated group ID
	ValidityDays int // subscription type only: validity period in days
}
// ProxyBatchDeleteResult reports which proxies a batch delete removed and
// which it skipped (with a reason for each).
type ProxyBatchDeleteResult struct {
	DeletedIDs []int64 `json:"deleted_ids"`
	Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
}
// ProxyBatchDeleteSkipped identifies a proxy that was not deleted and why.
type ProxyBatchDeleteSkipped struct {
	ID int64 `json:"id"`
	Reason string `json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy.
// Latency and exit-location fields are omitted from JSON when zero/empty.
type ProxyTestResult struct {
	Success bool `json:"success"`
	Message string `json:"message"`
	LatencyMs int64 `json:"latency_ms,omitempty"` // round-trip latency in milliseconds
	IPAddress string `json:"ip_address,omitempty"`
	City string `json:"city,omitempty"`
	Region string `json:"region,omitempty"`
	Country string `json:"country,omitempty"`
	CountryCode string `json:"country_code,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip-api.com:
// the public IP seen through the proxy and its geolocation.
type ProxyExitInfo struct {
	IP string
	City string
	Region string
	Country string
	CountryCode string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information.
type ProxyExitInfoProber interface {
	// ProbeProxy connects through proxyURL and returns the exit info.
	// NOTE(review): the int64 is presumably the measured latency in
	// milliseconds (cf. ProxyTestResult.LatencyMs) — confirm with implementations.
	ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// adminServiceImpl implements AdminService
// on top of the repositories and caches below. userGroupRateRepo,
// billingCacheService and authCacheInvalidator are nil-checked before use in
// the methods of this file, so they are effectively optional.
type adminServiceImpl struct {
	userRepo UserRepository
	groupRepo GroupRepository
	accountRepo AccountRepository
	proxyRepo ProxyRepository
	apiKeyRepo APIKeyRepository
	redeemCodeRepo RedeemCodeRepository // also stores admin balance/concurrency adjustment records
	userGroupRateRepo UserGroupRateRepository // per-user group rate overrides (optional)
	billingCacheService *BillingCacheService // optional
	proxyProber ProxyExitInfoProber
	proxyLatencyCache ProxyLatencyCache // presumably caches proxy latency probes — confirm
	authCacheInvalidator APIKeyAuthCacheInvalidator // optional
}
// NewAdminService assembles an AdminService from its dependencies.
// Dependencies are stored as given; nothing is validated here.
func NewAdminService(
	userRepo UserRepository,
	groupRepo GroupRepository,
	accountRepo AccountRepository,
	proxyRepo ProxyRepository,
	apiKeyRepo APIKeyRepository,
	redeemCodeRepo RedeemCodeRepository,
	userGroupRateRepo UserGroupRateRepository,
	billingCacheService *BillingCacheService,
	proxyProber ProxyExitInfoProber,
	proxyLatencyCache ProxyLatencyCache,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
	svc := new(adminServiceImpl)

	// Repositories.
	svc.userRepo = userRepo
	svc.groupRepo = groupRepo
	svc.accountRepo = accountRepo
	svc.proxyRepo = proxyRepo
	svc.apiKeyRepo = apiKeyRepo
	svc.redeemCodeRepo = redeemCodeRepo
	svc.userGroupRateRepo = userGroupRateRepo

	// Caches and helpers.
	svc.billingCacheService = billingCacheService
	svc.proxyProber = proxyProber
	svc.proxyLatencyCache = proxyLatencyCache
	svc.authCacheInvalidator = authCacheInvalidator

	return svc
}
// User management implementations

// ListUsers returns one page of users matching filters plus the total count.
// When a group-rate repository is configured, each user's group rate
// overrides are attached; a lookup failure for one user is logged and that
// user is returned without rates.
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
	users, pageResult, err := s.userRepo.ListWithFilters(ctx, pagination.PaginationParams{Page: page, PageSize: pageSize}, filters)
	if err != nil {
		return nil, 0, err
	}
	if s.userGroupRateRepo == nil {
		return users, pageResult.Total, nil
	}
	// One GetByUserID call per user on the page.
	for idx := range users {
		u := &users[idx]
		rates, rateErr := s.userGroupRateRepo.GetByUserID(ctx, u.ID)
		if rateErr != nil {
			log.Printf("failed to load user group rates: user_id=%d err=%v", u.ID, rateErr)
			continue
		}
		u.GroupRates = rates
	}
	return users, pageResult.Total, nil
}
// GetUser loads a user by ID. When a group-rate repository is configured,
// the user's group rate overrides are attached; if loading them fails the
// error is logged and the user is still returned.
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
	u, err := s.userRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if s.userGroupRateRepo == nil {
		return u, nil
	}
	rates, rateErr := s.userGroupRateRepo.GetByUserID(ctx, id)
	if rateErr != nil {
		log.Printf("failed to load user group rates: user_id=%d err=%v", id, rateErr)
		return u, nil
	}
	u.GroupRates = rates
	return u, nil
}
// CreateUser creates an active regular user from input. The role is always
// RoleUser, so admins can never be created through this path. The password
// is applied via User.SetPassword before the user is persisted.
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
	newUser := &User{
		Role:          RoleUser,
		Status:        StatusActive,
		Email:         input.Email,
		Username:      input.Username,
		Notes:         input.Notes,
		Balance:       input.Balance,
		Concurrency:   input.Concurrency,
		AllowedGroups: input.AllowedGroups,
	}

	var err error
	if err = newUser.SetPassword(input.Password); err != nil {
		return nil, err
	}
	if err = s.userRepo.Create(ctx, newUser); err != nil {
		return nil, err
	}
	return newUser, nil
}
// UpdateUser applies the non-empty / non-nil fields of input to user id and
// persists the result. Admin users cannot be disabled. input.Balance is not
// applied here; balance changes go through UpdateUserBalance.
//
// After a successful save:
//   - group rate overrides are synced when input.GroupRates is non-nil
//     (failures are logged, not returned);
//   - the user's auth cache is invalidated if concurrency, status or role
//     changed;
//   - when concurrency changed, a used "admin concurrency" adjustment record
//     carrying the delta is written (failures are logged, not returned).
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if input.Status == "disabled" && user.Role == "admin" {
		return nil, errors.New("cannot disable admin user")
	}

	// Snapshot the fields whose change triggers cache invalidation/auditing.
	prevConcurrency, prevStatus, prevRole := user.Concurrency, user.Status, user.Role

	if input.Email != "" {
		user.Email = input.Email
	}
	if input.Password != "" {
		if pwErr := user.SetPassword(input.Password); pwErr != nil {
			return nil, pwErr
		}
	}
	if input.Username != nil {
		user.Username = *input.Username
	}
	if input.Notes != nil {
		user.Notes = *input.Notes
	}
	if input.Status != "" {
		user.Status = input.Status
	}
	if input.Concurrency != nil {
		user.Concurrency = *input.Concurrency
	}
	if input.AllowedGroups != nil {
		user.AllowedGroups = *input.AllowedGroups
	}

	if saveErr := s.userRepo.Update(ctx, user); saveErr != nil {
		return nil, saveErr
	}

	if s.userGroupRateRepo != nil && input.GroupRates != nil {
		if syncErr := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); syncErr != nil {
			log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, syncErr)
		}
	}

	authRelevantChange := user.Concurrency != prevConcurrency ||
		user.Status != prevStatus ||
		user.Role != prevRole
	if s.authCacheInvalidator != nil && authRelevantChange {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
	}

	delta := user.Concurrency - prevConcurrency
	if delta == 0 {
		return user, nil
	}

	code, genErr := GenerateRedeemCode()
	if genErr != nil {
		log.Printf("failed to generate adjustment redeem code: %v", genErr)
		return user, nil
	}
	usedAt := time.Now()
	record := &RedeemCode{
		Code:   code,
		Type:   AdjustmentTypeAdminConcurrency,
		Value:  float64(delta),
		Status: StatusUsed,
		UsedBy: &user.ID,
		UsedAt: &usedAt,
	}
	if createErr := s.redeemCodeRepo.Create(ctx, record); createErr != nil {
		log.Printf("failed to create concurrency adjustment redeem code: %v", createErr)
	}
	return user, nil
}
// DeleteUser removes a non-admin user and then invalidates that user's auth
// cache. Admin users are refused; repository delete failures are logged and
// returned.
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
	target, err := s.userRepo.GetByID(ctx, id)
	switch {
	case err != nil:
		return err
	case target.Role == "admin":
		return errors.New("cannot delete admin user")
	}

	if err = s.userRepo.Delete(ctx, id); err != nil {
		log.Printf("delete user failed: user_id=%d err=%v", id, err)
		return err
	}

	if inv := s.authCacheInvalidator; inv != nil {
		inv.InvalidateAuthCacheByUserID(ctx, id)
	}
	return nil
}
// UpdateUserBalance changes a user's balance according to operation:
//   - "set":      the balance becomes exactly balance;
//   - "add":      balance is added to the current value;
//   - "subtract": balance is subtracted from the current value.
//
// Any other operation is rejected with an error before anything is written.
// The resulting balance must not be negative.
//
// After a successful save:
//   - the user's auth cache is invalidated when the balance actually changed;
//   - the billing balance cache is invalidated asynchronously with a detached
//     5s-bounded context, so it is not cancelled when the request ends;
//   - when the balance changed, a used "admin balance" adjustment record with
//     the signed difference and notes is written (failures are logged only).
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	oldBalance := user.Balance
	switch operation {
	case "set":
		user.Balance = balance
	case "add":
		user.Balance += balance
	case "subtract":
		user.Balance -= balance
	default:
		// Previously an unknown operation fell through and the unchanged user
		// was re-saved and reported as success.
		return nil, fmt.Errorf("unsupported balance operation %q: expected set, add or subtract", operation)
	}
	if user.Balance < 0 {
		return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
	}
	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, err
	}

	balanceDiff := user.Balance - oldBalance
	if s.authCacheInvalidator != nil && balanceDiff != 0 {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	if s.billingCacheService != nil {
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
				log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
			}
		}()
	}

	if balanceDiff == 0 {
		return user, nil
	}
	code, err := GenerateRedeemCode()
	if err != nil {
		log.Printf("failed to generate adjustment redeem code: %v", err)
		return user, nil
	}
	now := time.Now()
	adjustmentRecord := &RedeemCode{
		Code:   code,
		Type:   AdjustmentTypeAdminBalance,
		Value:  balanceDiff,
		Status: StatusUsed,
		UsedBy: &user.ID,
		UsedAt: &now,
		Notes:  notes,
	}
	if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
		log.Printf("failed to create balance adjustment redeem code: %v", err)
	}
	return user, nil
}
// GetUserAPIKeys returns one page of the API keys owned by userID together
// with the total count reported by the repository.
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
	var params pagination.PaginationParams
	params.Page, params.PageSize = page, pageSize

	keys, pageResult, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
	if err == nil {
		return keys, pageResult.Total, nil
	}
	return nil, 0, err
}
// GetUserUsageStats is a placeholder. It ignores ctx and userID and returns
// an all-zero statistics map that echoes period back. It never fails.
// TODO: back this with real usage data.
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
	stats := make(map[string]any, 5)
	stats["period"] = period
	stats["total_requests"] = 0
	stats["total_cost"] = 0.0
	stats["total_tokens"] = 0
	stats["avg_duration_ms"] = 0
	return stats, nil
}
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType optionally filters the page ("" means all types). The third result
// is the user's total recharged amount from SumPositiveBalanceByUser, which is
// computed independently of the codeType filter.
func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) {
	repo := s.redeemCodeRepo

	records, pageResult, err := repo.ListByUserPaginated(ctx, userID, pagination.PaginationParams{Page: page, PageSize: pageSize}, codeType)
	if err != nil {
		return nil, 0, 0, err
	}

	recharged, err := repo.SumPositiveBalanceByUser(ctx, userID)
	if err != nil {
		return nil, 0, 0, err
	}
	return records, pageResult.Total, recharged, nil
}
// Group management implementations

// ListGroups returns one page of groups plus the total count. The platform,
// status, search and isExclusive filters are passed to the repository as-is.
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
	groups, pageResult, err := s.groupRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		platform, status, search, isExclusive,
	)
	if err != nil {
		return nil, 0, err
	}
	return groups, pageResult.Total, nil
}
// GetAllGroups returns every active group, without pagination.
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
	active, err := s.groupRepo.ListActive(ctx)
	return active, err
}
// GetAllGroupsByPlatform returns every active group of platform, without pagination.
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
	active, err := s.groupRepo.ListActiveByPlatform(ctx, platform)
	return active, err
}
// GetGroup loads a single group by ID; repository errors are returned unchanged.
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
	g, err := s.groupRepo.GetByID(ctx, id)
	return g, err
}
// CreateGroup validates input and creates a new active group.
//
// Defaults and normalisation:
//   - empty Platform defaults to anthropic, empty SubscriptionType to standard;
//   - limit fields: nil and non-positive values both mean "unlimited";
//   - image prices: negative clears (default price applies), 0 means free;
//   - FallbackGroupID / FallbackGroupIDOnInvalidRequest: 0 or negative means
//     "none" (consistent with UpdateGroup); positive IDs are validated;
//   - MCPXMLInject defaults to true unless explicitly false.
//
// When CopyAccountsFromGroupIDs is set, the (deduplicated) source groups must
// exist and share the new group's platform. Their accounts are bound to the
// new group after it is created.
//
// NOTE(review): group creation and account binding are separate repository
// calls, so if binding fails the group has already been created.
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
	platform := input.Platform
	if platform == "" {
		platform = PlatformAnthropic
	}
	subscriptionType := input.SubscriptionType
	if subscriptionType == "" {
		subscriptionType = SubscriptionTypeStandard
	}

	dailyLimit := normalizeLimit(input.DailyLimitUSD)
	weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
	monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)

	imagePrice1K := normalizePrice(input.ImagePrice1K)
	imagePrice2K := normalizePrice(input.ImagePrice2K)
	imagePrice4K := normalizePrice(input.ImagePrice4K)

	// Fallback group: treat 0/negative as "none" instead of failing lookup of
	// a non-existent group, matching UpdateGroup's semantics.
	fallbackGroupID := input.FallbackGroupID
	if fallbackGroupID != nil && *fallbackGroupID <= 0 {
		fallbackGroupID = nil
	}
	if fallbackGroupID != nil {
		if err := s.validateFallbackGroup(ctx, 0, *fallbackGroupID); err != nil {
			return nil, err
		}
	}

	fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
	if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
		fallbackOnInvalidRequest = nil
	}
	if fallbackOnInvalidRequest != nil {
		if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
			return nil, err
		}
	}

	mcpXMLInject := true
	if input.MCPXMLInject != nil {
		mcpXMLInject = *input.MCPXMLInject
	}

	// Resolve the accounts to copy before creating anything, so invalid
	// sources are rejected up front.
	var accountIDsToCopy []int64
	if len(input.CopyAccountsFromGroupIDs) > 0 {
		seen := make(map[int64]struct{}, len(input.CopyAccountsFromGroupIDs))
		uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
		for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
			if _, exists := seen[srcGroupID]; !exists {
				seen[srcGroupID] = struct{}{}
				uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
			}
		}
		for _, srcGroupID := range uniqueSourceGroupIDs {
			srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
			if err != nil {
				return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
			}
			if srcGroup.Platform != platform {
				return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
			}
		}
		var err error
		accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
		if err != nil {
			return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
		}
	}

	group := &Group{
		Name:                            input.Name,
		Description:                     input.Description,
		Platform:                        platform,
		RateMultiplier:                  input.RateMultiplier,
		IsExclusive:                     input.IsExclusive,
		Status:                          StatusActive,
		SubscriptionType:                subscriptionType,
		DailyLimitUSD:                   dailyLimit,
		WeeklyLimitUSD:                  weeklyLimit,
		MonthlyLimitUSD:                 monthlyLimit,
		ImagePrice1K:                    imagePrice1K,
		ImagePrice2K:                    imagePrice2K,
		ImagePrice4K:                    imagePrice4K,
		ClaudeCodeOnly:                  input.ClaudeCodeOnly,
		FallbackGroupID:                 fallbackGroupID,
		FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
		ModelRouting:                    input.ModelRouting,
		// Previously dropped: the input flag was never copied to the group.
		ModelRoutingEnabled:  input.ModelRoutingEnabled,
		MCPXMLInject:         mcpXMLInject,
		SupportedModelScopes: input.SupportedModelScopes,
	}
	if err := s.groupRepo.Create(ctx, group); err != nil {
		return nil, err
	}

	if len(accountIDsToCopy) > 0 {
		if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
			return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
		}
		group.AccountCount = int64(len(accountIDsToCopy))
	}
	return group, nil
}
// normalizeLimit maps "no limit" inputs to nil: a nil pointer or a value <= 0
// yields nil (meaning unlimited). Any other value, including NaN, is returned
// as the same pointer without copying.
func normalizeLimit(limit *float64) *float64 {
	switch {
	case limit == nil, *limit <= 0:
		// Case expressions are evaluated left to right, so *limit is only
		// read once limit is known to be non-nil.
		return nil
	default:
		return limit
	}
}
// normalizePrice maps a nil pointer or a negative price to nil, meaning "use
// the default price". Zero is kept and means free; any other value, including
// NaN, is returned as the same pointer without copying.
func normalizePrice(price *float64) *float64 {
	switch {
	case price == nil, *price < 0:
		// *price is only evaluated when price is non-nil (left-to-right).
		return nil
	}
	return price
}
// validateFallbackGroup checks that fallbackGroupID may be used as the
// fallback (downgrade) group of currentGroupID. currentGroupID is 0 when the
// group is being created.
//
// It walks the fallback chain starting at fallbackGroupID and rejects:
//   - a group that names itself as its fallback;
//   - a chain that revisits a group or leads back to currentGroupID (cycle);
//   - any group in the chain that cannot be loaded;
//   - a direct fallback with claude_code_only enabled (per the original
//     note, that would otherwise cause an endless loop).
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
	isEditing := currentGroupID > 0
	if isEditing && fallbackGroupID == currentGroupID {
		return fmt.Errorf("cannot set self as fallback group")
	}

	seen := make(map[int64]struct{})
	for cursor := fallbackGroupID; ; {
		if _, dup := seen[cursor]; dup {
			return fmt.Errorf("fallback group cycle detected")
		}
		seen[cursor] = struct{}{}
		if isEditing && cursor == currentGroupID {
			return fmt.Errorf("fallback group cycle detected")
		}

		g, err := s.groupRepo.GetByIDLite(ctx, cursor)
		if err != nil {
			return fmt.Errorf("fallback group not found: %w", err)
		}
		// Only the directly configured fallback is checked for claude_code_only.
		if cursor == fallbackGroupID && g.ClaudeCodeOnly {
			return fmt.Errorf("fallback group cannot have claude_code_only enabled")
		}

		if g.FallbackGroupID == nil {
			return nil
		}
		cursor = *g.FallbackGroupID
	}
}
// validateFallbackGroupOnInvalidRequest checks that fallbackGroupID may serve
// as the invalid-request fallback of a group.
//
// currentGroupID is the owning group's ID (0 when creating). platform and
// subscriptionType are the owning group's effective values.
//
// The owning group must be anthropic or antigravity, must not be a
// subscription group, and must not point at itself. The target group must
// exist, be an anthropic non-subscription group, and must not have its own
// invalid-request fallback configured.
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
	// Constraints on the group that owns the setting.
	switch {
	case platform != PlatformAnthropic && platform != PlatformAntigravity:
		return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
	case subscriptionType == SubscriptionTypeSubscription:
		return fmt.Errorf("subscription groups cannot set invalid request fallback")
	case currentGroupID > 0 && currentGroupID == fallbackGroupID:
		return fmt.Errorf("cannot set self as invalid request fallback group")
	}

	target, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
	if err != nil {
		return fmt.Errorf("fallback group not found: %w", err)
	}

	// Constraints on the target group.
	switch {
	case target.Platform != PlatformAnthropic:
		return fmt.Errorf("fallback group must be anthropic platform")
	case target.SubscriptionType == SubscriptionTypeSubscription:
		return fmt.Errorf("fallback group cannot be subscription type")
	case target.FallbackGroupIDOnInvalidRequest != nil:
		return fmt.Errorf("fallback group cannot have invalid request fallback configured")
	}
	return nil
}
// UpdateGroup applies a partial update to group id and persists it.
//
// Field semantics:
//   - empty strings and nil pointers/maps leave the field unchanged;
//   - limit fields: 0 or negative means unlimited;
//   - image prices: negative clears (default price), 0 means free;
//   - FallbackGroupID / FallbackGroupIDOnInvalidRequest: > 0 sets (after
//     validation), 0 or negative clears. The effective invalid-request
//     fallback is re-validated against the group's (possibly updated)
//     platform and subscription type even when the input does not touch it;
//   - CopyAccountsFromGroupIDs: replaces this group's account bindings with
//     the deduplicated accounts of the source groups.
//
// All validation, including the account-copy sources, runs before the group
// is persisted, so a rejected request leaves the stored group unchanged. On
// success the group's auth cache is invalidated.
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
	group, err := s.groupRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if input.Name != "" {
		group.Name = input.Name
	}
	if input.Description != "" {
		group.Description = input.Description
	}
	if input.Platform != "" {
		group.Platform = input.Platform
	}
	if input.RateMultiplier != nil {
		group.RateMultiplier = *input.RateMultiplier
	}
	if input.IsExclusive != nil {
		group.IsExclusive = *input.IsExclusive
	}
	if input.Status != "" {
		group.Status = input.Status
	}

	// Subscription-related fields.
	if input.SubscriptionType != "" {
		group.SubscriptionType = input.SubscriptionType
	}
	if input.DailyLimitUSD != nil {
		group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
	}
	if input.WeeklyLimitUSD != nil {
		group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
	}
	if input.MonthlyLimitUSD != nil {
		group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
	}

	// Image generation pricing.
	if input.ImagePrice1K != nil {
		group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
	}
	if input.ImagePrice2K != nil {
		group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
	}
	if input.ImagePrice4K != nil {
		group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
	}

	// Claude Code client restriction.
	if input.ClaudeCodeOnly != nil {
		group.ClaudeCodeOnly = *input.ClaudeCodeOnly
	}

	if input.FallbackGroupID != nil {
		if *input.FallbackGroupID > 0 {
			if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
				return nil, err
			}
			group.FallbackGroupID = input.FallbackGroupID
		} else {
			// 0 or negative clears the fallback group.
			group.FallbackGroupID = nil
		}
	}

	fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
	if input.FallbackGroupIDOnInvalidRequest != nil {
		if *input.FallbackGroupIDOnInvalidRequest > 0 {
			fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
		} else {
			fallbackOnInvalidRequest = nil
		}
	}
	if fallbackOnInvalidRequest != nil {
		if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
			return nil, err
		}
	}
	group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest

	// Model routing configuration.
	if input.ModelRouting != nil {
		group.ModelRouting = input.ModelRouting
	}
	if input.ModelRoutingEnabled != nil {
		group.ModelRoutingEnabled = *input.ModelRoutingEnabled
	}
	if input.MCPXMLInject != nil {
		group.MCPXMLInject = *input.MCPXMLInject
	}
	// Supported model families (antigravity platform only).
	if input.SupportedModelScopes != nil {
		group.SupportedModelScopes = *input.SupportedModelScopes
	}

	// Resolve account-copy sources BEFORE persisting. Previously these checks
	// ran after groupRepo.Update, so a rejected source left the group
	// half-updated (and its auth cache stale).
	copyAccounts := len(input.CopyAccountsFromGroupIDs) > 0
	var accountIDsToCopy []int64
	if copyAccounts {
		seen := make(map[int64]struct{}, len(input.CopyAccountsFromGroupIDs))
		uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
		for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
			if srcGroupID == id {
				return nil, fmt.Errorf("cannot copy accounts from self")
			}
			if _, exists := seen[srcGroupID]; !exists {
				seen[srcGroupID] = struct{}{}
				uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
			}
		}
		// Sources must share this group's (possibly updated) platform.
		for _, srcGroupID := range uniqueSourceGroupIDs {
			srcGroup, srcErr := s.groupRepo.GetByIDLite(ctx, srcGroupID)
			if srcErr != nil {
				return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, srcErr)
			}
			if srcGroup.Platform != group.Platform {
				return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
			}
		}
		accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
		if err != nil {
			return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
		}
	}

	if err := s.groupRepo.Update(ctx, group); err != nil {
		return nil, err
	}

	if copyAccounts {
		// NOTE(review): clearing and re-binding are separate repository calls;
		// a failure between them leaves the group with no/partial bindings.
		if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
			return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
		}
		if len(accountIDsToCopy) > 0 {
			if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
				return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
			}
		}
	}

	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
	}
	return group, nil
}
// DeleteGroup deletes group id through the repository's cascade delete and
// then invalidates caches that may still reference it:
//   - the subscription caches of affected users, asynchronously with a
//     detached 30s-bounded context, when a billing cache service is set;
//   - the auth cache entries of the API keys that belonged to the group.
//
// The group's API keys are listed before deletion (presumably because they
// can no longer be listed by group afterwards). If that listing fails, the
// deletion still proceeds and the failure is logged, because those keys'
// auth cache entries will not be invalidated here.
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
	var groupKeys []string
	if s.authCacheInvalidator != nil {
		keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
		if err != nil {
			// Best effort, but no longer silent: stale auth cache entries
			// for this group's keys may survive the deletion.
			log.Printf("list group api keys for auth cache invalidation failed: group_id=%d err=%v", id, err)
		} else {
			groupKeys = keys
		}
	}

	affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
	if err != nil {
		return err
	}
	// Per the original note, user_group_rate_multipliers rows are removed by
	// the database via ON DELETE CASCADE.

	if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()
			for _, userID := range affectedUserIDs {
				if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, id); err != nil {
					log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, id, err)
				}
			}
		}()
	}

	if s.authCacheInvalidator != nil {
		for _, key := range groupKeys {
			s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
		}
	}
	return nil
}
// GetGroupAPIKeys returns one page of the API keys bound to groupID together
// with the total count reported by the repository.
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
	var params pagination.PaginationParams
	params.Page, params.PageSize = page, pageSize

	keys, pageResult, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
	if err == nil {
		return keys, pageResult.Total, nil
	}
	return nil, 0, err
}
// Account management implementations

// ListAccounts returns one page of accounts plus the total count. The
// platform, accountType, status and search filters are passed to the
// repository as-is.
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
	accounts, pageResult, err := s.accountRepo.ListWithFilters(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		platform, accountType, status, search,
	)
	if err != nil {
		return nil, 0, err
	}
	return accounts, pageResult.Total, nil
}
// GetAccount loads a single account by ID; repository errors are returned unchanged.
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
	acct, err := s.accountRepo.GetByID(ctx, id)
	return acct, err
}
// GetAccountsByIDs loads the accounts with the given IDs. An empty ids slice
// returns an empty, non-nil slice without touching the repository.
// Repository errors are wrapped with context.
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	if len(ids) > 0 {
		found, err := s.accountRepo.GetByIDs(ctx, ids)
		if err != nil {
			return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
		}
		return found, nil
	}
	return []*Account{}, nil
}
// CreateAccount creates an active, schedulable account and binds it to groups.
//
// Group binding: input.GroupIDs is used when non-empty. Otherwise, unless
// SkipDefaultGroupBind is set, the account is bound to the active group named
// "<platform>-default" if one exists. A failed lookup is logged and the
// account is created without groups.
//
// Unless SkipMixedChannelCheck is set, the target groups are checked with
// checkMixedChannelRisk before anything is written.
//
// Defaults:
//   - AutoPauseOnExpired is true when not provided;
//   - ExpiresAt (unix seconds) is only applied when > 0;
//   - RateMultiplier, when provided, must be >= 0.
//
// NOTE(review): account creation and group binding are separate repository
// calls, so if BindGroups fails the account has already been created.
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
	// Reject invalid input before any repository round-trips.
	if input.RateMultiplier != nil && *input.RateMultiplier < 0 {
		return nil, errors.New("rate_multiplier must be >= 0")
	}

	groupIDs := input.GroupIDs
	if len(groupIDs) == 0 && !input.SkipDefaultGroupBind {
		defaultGroupName := input.Platform + "-default"
		groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
		if err != nil {
			// Best effort: a failed default-group lookup must not block
			// account creation, but it should not be silent either.
			log.Printf("lookup default group failed: platform=%q err=%v", input.Platform, err)
		} else {
			for _, g := range groups {
				if g.Name == defaultGroupName {
					groupIDs = []int64{g.ID}
					break
				}
			}
		}
	}

	if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
		if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
			return nil, err
		}
	}

	autoPauseOnExpired := true
	if input.AutoPauseOnExpired != nil {
		autoPauseOnExpired = *input.AutoPauseOnExpired
	}

	account := &Account{
		Name:               input.Name,
		Notes:              normalizeAccountNotes(input.Notes),
		Platform:           input.Platform,
		Type:               input.Type,
		Credentials:        input.Credentials,
		Extra:              input.Extra,
		ProxyID:            input.ProxyID,
		Concurrency:        input.Concurrency,
		Priority:           input.Priority,
		Status:             StatusActive,
		Schedulable:        true,
		AutoPauseOnExpired: autoPauseOnExpired,
		RateMultiplier:     input.RateMultiplier, // nil keeps the default
	}
	if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
		expiresAt := time.Unix(*input.ExpiresAt, 0)
		account.ExpiresAt = &expiresAt
	}

	if err := s.accountRepo.Create(ctx, account); err != nil {
		return nil, err
	}
	if len(groupIDs) > 0 {
		if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
			return nil, err
		}
	}
	return account, nil
}
// UpdateAccount applies a partial update to the account identified by id and returns the
// freshly reloaded record.
//
// Field semantics:
//   - Empty strings / empty maps (Name, Type, Status, Credentials, Extra) mean "leave unchanged".
//   - Pointer fields are applied only when non-nil, so zero values (e.g. Concurrency 0,
//     Priority 0) can be set explicitly.
//   - ProxyID 0 clears the proxy (the frontend sends 0 rather than null to express this).
//   - ExpiresAt <= 0 clears the expiry; a positive value is a Unix timestamp in seconds.
//   - GroupIDs, when non-nil, replaces the account's group bindings. Every group must exist,
//     and the mixed-channel check runs unless SkipMixedChannelCheck is set.
//
// All validation happens before the first write.
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
	acct, err := s.accountRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Plain fields: an empty value means "not provided".
	if input.Name != "" {
		acct.Name = input.Name
	}
	if input.Type != "" {
		acct.Type = input.Type
	}
	if input.Status != "" {
		acct.Status = input.Status
	}
	if input.Notes != nil {
		acct.Notes = normalizeAccountNotes(input.Notes)
	}
	if len(input.Credentials) > 0 {
		acct.Credentials = input.Credentials
	}
	if len(input.Extra) > 0 {
		acct.Extra = input.Extra
	}

	if input.ProxyID != nil {
		acct.ProxyID = input.ProxyID
		if *input.ProxyID == 0 {
			acct.ProxyID = nil // 0 means "clear the proxy"
		}
		// Drop the loaded association so saving cannot overwrite ProxyID from a stale Proxy.ID.
		acct.Proxy = nil
	}

	// Pointer fields: nil means "not provided", so 0 can be set explicitly.
	if input.Concurrency != nil {
		acct.Concurrency = *input.Concurrency
	}
	if input.Priority != nil {
		acct.Priority = *input.Priority
	}
	if input.AutoPauseOnExpired != nil {
		acct.AutoPauseOnExpired = *input.AutoPauseOnExpired
	}
	if input.ExpiresAt != nil {
		acct.ExpiresAt = nil
		if *input.ExpiresAt > 0 {
			expiresAt := time.Unix(*input.ExpiresAt, 0)
			acct.ExpiresAt = &expiresAt
		}
	}
	if rm := input.RateMultiplier; rm != nil {
		if *rm < 0 {
			return nil, errors.New("rate_multiplier must be >= 0")
		}
		acct.RateMultiplier = rm
	}

	// Validate the requested groups before any write happens.
	if input.GroupIDs != nil {
		requested := *input.GroupIDs
		for _, groupID := range requested {
			if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
				return nil, fmt.Errorf("get group: %w", err)
			}
		}
		if !input.SkipMixedChannelCheck {
			if err := s.checkMixedChannelRisk(ctx, acct.ID, acct.Platform, requested); err != nil {
				return nil, err
			}
		}
	}

	if err := s.accountRepo.Update(ctx, acct); err != nil {
		return nil, err
	}
	if input.GroupIDs != nil {
		if err := s.accountRepo.BindGroups(ctx, acct.ID, *input.GroupIDs); err != nil {
			return nil, err
		}
	}

	// Reload so the response carries complete data, including the Proxy association.
	return s.accountRepo.GetByID(ctx, id)
}
// BulkUpdateAccounts updates multiple accounts in one request.
//
// Column and JSONB fields are written for all AccountIDs with a single repository call; for
// credentials/extra the keys are merged rather than overwriting the whole object. When GroupIDs
// is non-nil, group bindings are then applied account by account. Per-account failures (platform
// lookup, mixed-channel risk, binding) are recorded in the result instead of aborting the request.
//
// NOTE(review): the bulk column update runs before the per-account group handling, so an
// account reported as failed may still have had its column/JSONB fields updated.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
	total := len(input.AccountIDs)
	result := &BulkUpdateAccountsResult{
		SuccessIDs: make([]int64, 0, total),
		FailedIDs:  make([]int64, 0, total),
		Results:    make([]BulkUpdateAccountResult, 0, total),
	}
	if total == 0 {
		return result, nil
	}

	// Validate input before any database access.
	if input.RateMultiplier != nil && *input.RateMultiplier < 0 {
		return nil, errors.New("rate_multiplier must be >= 0")
	}

	// Preload account platforms for mixed-channel risk checks if group bindings are requested.
	platformByID := map[int64]string{}
	if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
		accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
		if err != nil {
			return nil, err
		}
		for _, account := range accounts {
			if account != nil {
				platformByID[account.ID] = account.Platform
			}
		}
	}

	// Prepare bulk updates for columns and JSONB fields; nil pointers mean "not provided".
	repoUpdates := AccountBulkUpdate{
		Credentials:    input.Credentials,
		Extra:          input.Extra,
		ProxyID:        input.ProxyID,
		Concurrency:    input.Concurrency,
		Priority:       input.Priority,
		RateMultiplier: input.RateMultiplier,
		Schedulable:    input.Schedulable,
	}
	if input.Name != "" {
		repoUpdates.Name = &input.Name
	}
	if input.Status != "" {
		repoUpdates.Status = &input.Status
	}
	if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
		return nil, err
	}

	recordFailure := func(accountID int64, err error) {
		result.Failed++
		result.FailedIDs = append(result.FailedIDs, accountID)
		result.Results = append(result.Results, BulkUpdateAccountResult{
			AccountID: accountID,
			Success:   false,
			Error:     err.Error(),
		})
	}
	recordSuccess := func(accountID int64) {
		result.Success++
		result.SuccessIDs = append(result.SuccessIDs, accountID)
		result.Results = append(result.Results, BulkUpdateAccountResult{AccountID: accountID, Success: true})
	}

	// bindGroups runs the mixed-channel check (unless confirmed by the user) and then binds
	// *input.GroupIDs to a single account. Only called when input.GroupIDs is non-nil.
	bindGroups := func(accountID int64) error {
		if !input.SkipMixedChannelCheck {
			platform := platformByID[accountID]
			if platform == "" {
				account, err := s.accountRepo.GetByID(ctx, accountID)
				if err != nil {
					return err
				}
				platform = account.Platform
			}
			if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
				return err
			}
		}
		return s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs)
	}

	// Group bindings require individual operations per account.
	for _, accountID := range input.AccountIDs {
		if input.GroupIDs != nil {
			if err := bindGroups(accountID); err != nil {
				recordFailure(accountID, err)
				continue
			}
		}
		recordSuccess(accountID)
	}
	return result, nil
}
// DeleteAccount removes the account with the given ID, delegating to the repository.
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
	err := s.accountRepo.Delete(ctx, id)
	return err
}
// RefreshAccountCredentials is currently a stub: it only loads and returns the account
// unchanged. No credentials are refreshed yet.
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
	acct, err := s.accountRepo.GetByID(ctx, id)
	if err == nil {
		// TODO: Implement refresh logic
		return acct, nil
	}
	return nil, err
}
// ClearAccountError resets an account to the active status with an empty error message,
// persists the change and returns the updated account.
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
	acct, err := s.accountRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	acct.Status, acct.ErrorMessage = StatusActive, ""
	if err = s.accountRepo.Update(ctx, acct); err != nil {
		return nil, err
	}
	return acct, nil
}
// SetAccountError records errorMsg on the account via the repository's SetError.
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
	err := s.accountRepo.SetError(ctx, id, errorMsg)
	return err
}
// SetAccountSchedulable toggles whether the account may be scheduled, then reloads and
// returns it.
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
	err := s.accountRepo.SetSchedulable(ctx, id, schedulable)
	if err != nil {
		return nil, err
	}
	return s.accountRepo.GetByID(ctx, id)
}
// Proxy management implementations
// ListProxies returns one page of proxies matching the given filters and the total number
// of matches.
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
	proxies, pageInfo, err := s.proxyRepo.ListWithFilters(ctx, pagination.PaginationParams{Page: page, PageSize: pageSize}, protocol, status, search)
	if err != nil {
		return nil, 0, err
	}
	return proxies, pageInfo.Total, nil
}
// ListProxiesWithAccountCount returns one page of proxies (with their account counts)
// matching the filters, decorated with cached latency data, plus the total match count.
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
	proxies, pageInfo, err := s.proxyRepo.ListWithFiltersAndAccountCount(
		ctx,
		pagination.PaginationParams{Page: page, PageSize: pageSize},
		protocol, status, search,
	)
	if err != nil {
		return nil, 0, err
	}
	s.attachProxyLatency(ctx, proxies)
	return proxies, pageInfo.Total, nil
}
// GetAllProxies returns every proxy the repository considers active.
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
	active, err := s.proxyRepo.ListActive(ctx)
	return active, err
}
// GetAllProxiesWithAccountCount returns every active proxy with its account count,
// decorated with cached latency data.
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
	active, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
	if err == nil {
		s.attachProxyLatency(ctx, active)
		return active, nil
	}
	return nil, err
}
// GetProxy loads a single proxy by its ID, passing the repository result through unchanged.
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
	p, err := s.proxyRepo.GetByID(ctx, id)
	return p, err
}
// GetProxiesByIDs loads the proxies with the given IDs.
// Like GetAccountsByIDs, an empty ids slice short-circuits to an empty (non-nil) result
// without a repository round trip.
func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
	if len(ids) == 0 {
		return []Proxy{}, nil
	}
	return s.proxyRepo.ListByIDs(ctx, ids)
}
// CreateProxy persists a new active proxy and starts a background latency probe for it.
//
// The probe runs in its own goroutine so creation is not blocked by network timeouts. It uses a
// context detached from the request, because the request context may be cancelled as soon as
// this returns. That context is bounded by a timeout, so the goroutine does not run unbounded if
// the prober honors ctx. The probe works on a copy of the proxy, so callers may modify the
// returned value without racing with it.
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
	proxy := &Proxy{
		Name:     input.Name,
		Protocol: input.Protocol,
		Host:     input.Host,
		Port:     input.Port,
		Username: input.Username,
		Password: input.Password,
		Status:   StatusActive,
	}
	if err := s.proxyRepo.Create(ctx, proxy); err != nil {
		return nil, err
	}

	probeTarget := *proxy
	go func() {
		probeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		s.probeProxyLatency(probeCtx, &probeTarget)
	}()
	return proxy, nil
}
// UpdateProxy applies a partial update to the proxy identified by id and returns it.
// Empty strings and a zero Port mean "leave unchanged", so these fields cannot be cleared
// through this method.
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
	p, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if input.Name != "" {
		p.Name = input.Name
	}
	if input.Protocol != "" {
		p.Protocol = input.Protocol
	}
	if input.Host != "" {
		p.Host = input.Host
	}
	if input.Port != 0 {
		p.Port = input.Port
	}
	if input.Username != "" {
		p.Username = input.Username
	}
	if input.Password != "" {
		p.Password = input.Password
	}
	if input.Status != "" {
		p.Status = input.Status
	}

	if err = s.proxyRepo.Update(ctx, p); err != nil {
		return nil, err
	}
	return p, nil
}
// DeleteProxy removes a proxy. It returns ErrProxyInUse when any account still references it.
// NOTE(review): the count and the delete are separate calls, so an account bound in between
// is not detected here.
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
	inUse, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
	switch {
	case err != nil:
		return err
	case inUse > 0:
		return ErrProxyInUse
	default:
		return s.proxyRepo.Delete(ctx, id)
	}
}
// BatchDeleteProxies deletes each proxy in ids independently.
// A proxy is skipped, with the reason recorded in result.Skipped, when counting its accounts
// fails, when it is still in use (ErrProxyInUse), or when the delete itself fails. The
// method never returns an error.
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
	result := &ProxyBatchDeleteResult{}
	skip := func(id int64, reason string) {
		result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ID: id, Reason: reason})
	}

	for _, id := range ids {
		inUse, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
		switch {
		case err != nil:
			skip(id, err.Error())
		case inUse > 0:
			skip(id, ErrProxyInUse.Error())
		default:
			if delErr := s.proxyRepo.Delete(ctx, id); delErr != nil {
				skip(id, delErr.Error())
			} else {
				result.DeletedIDs = append(result.DeletedIDs, id)
			}
		}
	}
	return result, nil
}
// GetProxyAccounts returns summaries of the accounts that use the given proxy.
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
	summaries, err := s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
	return summaries, err
}
// CheckProxyExists reports whether a proxy with the same host, port and credentials is
// already stored.
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
	exists, err := s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
	return exists, err
}
// Redeem code management implementations
// ListRedeemCodes returns one page of redeem codes matching the filters and the total
// number of matches.
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
	codes, pageInfo, err := s.redeemCodeRepo.ListWithFilters(ctx, pagination.PaginationParams{Page: page, PageSize: pageSize}, codeType, status, search)
	if err != nil {
		return nil, 0, err
	}
	return codes, pageInfo.Total, nil
}
// GetRedeemCode loads a single redeem code by its ID.
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
	rc, err := s.redeemCodeRepo.GetByID(ctx, id)
	return rc, err
}
// GenerateRedeemCodes creates input.Count unused redeem codes of input.Type with input.Value,
// persisting each one, and returns them in creation order.
//
// Subscription codes additionally require GroupID to reference a subscription-type group, and
// they carry ValidityDays (defaulting to 30 when <= 0). A negative Count is rejected with an
// error; a Count of 0 yields an empty slice.
//
// NOTE(review): codes are inserted one at a time, so an error part-way through leaves earlier
// codes persisted while the caller receives only the error.
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
	// make() with a negative capacity panics at runtime; reject it explicitly instead.
	if input.Count < 0 {
		return nil, errors.New("count must be >= 0")
	}

	isSubscription := input.Type == RedeemTypeSubscription
	validityDays := input.ValidityDays
	if isSubscription {
		// Subscription codes must target an existing subscription-type group.
		if input.GroupID == nil {
			return nil, errors.New("group_id is required for subscription type")
		}
		group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
		if err != nil {
			return nil, fmt.Errorf("group not found: %w", err)
		}
		if !group.IsSubscriptionType() {
			return nil, errors.New("group must be subscription type")
		}
		if validityDays <= 0 {
			validityDays = 30 // default validity: 30 days
		}
	}

	codes := make([]RedeemCode, 0, input.Count)
	for i := 0; i < input.Count; i++ {
		codeValue, err := GenerateRedeemCode()
		if err != nil {
			return nil, err
		}
		code := RedeemCode{
			Code:   codeValue,
			Type:   input.Type,
			Value:  input.Value,
			Status: StatusUnused,
		}
		// Subscription-only fields.
		if isSubscription {
			code.GroupID = input.GroupID
			code.ValidityDays = validityDays
		}
		if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
			return nil, err
		}
		codes = append(codes, code)
	}
	return codes, nil
}
// DeleteRedeemCode removes the redeem code with the given ID.
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
	err := s.redeemCodeRepo.Delete(ctx, id)
	return err
}
// BatchDeleteRedeemCodes deletes each code in ids independently and returns how many
// deletions succeeded.
//
// Individual delete failures are best effort: they are logged and skipped, not returned.
// If ctx is cancelled mid-batch, the loop stops and returns the count so far together
// with ctx.Err(), instead of issuing further doomed calls.
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
	var deleted int64
	for _, id := range ids {
		if err := ctx.Err(); err != nil {
			return deleted, err
		}
		if err := s.redeemCodeRepo.Delete(ctx, id); err != nil {
			log.Printf("Warning: delete redeem code failed: id=%d err=%v", id, err)
			continue
		}
		deleted++
	}
	return deleted, nil
}
// ExpireRedeemCode sets the code's status to StatusExpired, persists it and returns it.
// NOTE(review): the current status is not checked, so a code in any state (including one
// already redeemed) is overwritten — confirm that this is intended.
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
	rc, err := s.redeemCodeRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	rc.Status = StatusExpired
	if err = s.redeemCodeRepo.Update(ctx, rc); err != nil {
		return nil, err
	}
	return rc, nil
}
// TestProxy probes the proxy identified by id, caches the outcome via saveProxyLatency and
// returns it.
//
// A probe failure is reported in the result (Success=false, Message=error text) rather than as
// an error. An error is returned only when the proxy cannot be loaded or no prober is configured.
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
	proxy, err := s.proxyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	// probeProxyLatency tolerates a missing prober; mirror that rather than calling through nil.
	if s.proxyProber == nil {
		return nil, errors.New("proxy prober is not configured")
	}

	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
	if err != nil {
		s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
			Success:   false,
			Message:   err.Error(),
			UpdatedAt: time.Now(),
		})
		return &ProxyTestResult{
			Success: false,
			Message: err.Error(),
		}, nil
	}

	latency := latencyMs
	s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
		Success:     true,
		LatencyMs:   &latency,
		Message:     "Proxy is accessible",
		IPAddress:   exitInfo.IP,
		Country:     exitInfo.Country,
		CountryCode: exitInfo.CountryCode,
		Region:      exitInfo.Region,
		City:        exitInfo.City,
		UpdatedAt:   time.Now(),
	})
	return &ProxyTestResult{
		Success:     true,
		Message:     "Proxy is accessible",
		LatencyMs:   latencyMs,
		IPAddress:   exitInfo.IP,
		City:        exitInfo.City,
		Region:      exitInfo.Region,
		Country:     exitInfo.Country,
		CountryCode: exitInfo.CountryCode,
	}, nil
}
// probeProxyLatency probes proxy through the configured prober and stores the outcome
// (success with latency and exit-IP details, or failure with the error text) via
// saveProxyLatency. It does nothing when no prober is configured or proxy is nil.
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
	if proxy == nil || s.proxyProber == nil {
		return
	}

	exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
	var info *ProxyLatencyInfo
	if err != nil {
		info = &ProxyLatencyInfo{
			Success:   false,
			Message:   err.Error(),
			UpdatedAt: time.Now(),
		}
	} else {
		latency := latencyMs
		info = &ProxyLatencyInfo{
			Success:     true,
			LatencyMs:   &latency,
			Message:     "Proxy is accessible",
			IPAddress:   exitInfo.IP,
			Country:     exitInfo.Country,
			CountryCode: exitInfo.CountryCode,
			Region:      exitInfo.Region,
			City:        exitInfo.City,
			UpdatedAt:   time.Now(),
		}
	}
	s.saveProxyLatency(ctx, proxy.ID, info)
}
// checkMixedChannelRisk reports whether binding an account to groupIDs would mix Antigravity
// and Anthropic accounts in the same group.
//
// The check is based on the account's platform field (not its type). Accounts whose platform
// maps to neither channel are ignored, both for the current account and for the group members.
// The current account itself (when currentAccountID > 0) is skipped. The first conflict found
// is returned as a *MixedChannelError so the caller can ask the user to confirm.
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
	current := getAccountPlatform(currentAccountPlatform)
	if current == "" {
		// Neither Antigravity nor Anthropic: nothing to check.
		return nil
	}

	for _, groupID := range groupIDs {
		members, err := s.accountRepo.ListByGroup(ctx, groupID)
		if err != nil {
			return fmt.Errorf("get accounts in group %d: %w", groupID, err)
		}
		for _, member := range members {
			if currentAccountID > 0 && member.ID == currentAccountID {
				continue // the account being checked
			}
			other := getAccountPlatform(member.Platform)
			if other == "" || other == current {
				continue // unrelated platform or same channel
			}
			// Best effort: the group name is only for the message, so a lookup failure
			// falls back to a generic label instead of failing the check.
			groupName := fmt.Sprintf("Group %d", groupID)
			if group, _ := s.groupRepo.GetByID(ctx, groupID); group != nil {
				groupName = group.Name
			}
			return &MixedChannelError{
				GroupID:         groupID,
				GroupName:       groupName,
				CurrentPlatform: current,
				OtherPlatform:   other,
			}
		}
	}
	return nil
}
// attachProxyLatency decorates proxies in place with the latest cached probe result, if any.
// It is a no-op when no latency cache is configured or the slice is empty. A cache read
// failure is logged and leaves every proxy untouched. Proxies without a cached entry are
// skipped.
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
	if len(proxies) == 0 || s.proxyLatencyCache == nil {
		return
	}

	ids := make([]int64, len(proxies))
	for i := range proxies {
		ids[i] = proxies[i].ID
	}

	latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
	if err != nil {
		log.Printf("Warning: load proxy latency cache failed: %v", err)
		return
	}

	for i := range proxies {
		p := &proxies[i]
		info := latencies[p.ID]
		if info == nil {
			continue
		}
		// LatencyMs is only copied for successful probes.
		p.LatencyStatus = "failed"
		if info.Success {
			p.LatencyStatus = "success"
			p.LatencyMs = info.LatencyMs
		}
		p.LatencyMessage = info.Message
		p.IPAddress = info.IPAddress
		p.Country = info.Country
		p.CountryCode = info.CountryCode
		p.Region = info.Region
		p.City = info.City
	}
}
// saveProxyLatency stores info for proxyID in the latency cache. It is a no-op when no
// cache is configured or info is nil; write failures are logged, not returned.
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
	if s.proxyLatencyCache != nil && info != nil {
		if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
			log.Printf("Warning: store proxy latency cache failed: %v", err)
		}
	}
}
// getAccountPlatform maps an account's platform field to the channel label used by the
// mixed-channel check: "Antigravity", "Anthropic" (also for "claude"), or "" for any other
// platform. Input is trimmed and lower-cased before matching.
func getAccountPlatform(accountPlatform string) string {
	normalized := strings.ToLower(strings.TrimSpace(accountPlatform))
	if normalized == PlatformAntigravity {
		return "Antigravity"
	}
	if normalized == PlatformAnthropic || normalized == "claude" {
		return "Anthropic"
	}
	return ""
}
// MixedChannelError signals that a group would contain accounts from both the Antigravity
// and the Anthropic channel. It is returned as a warning the caller may confirm and bypass.
type MixedChannelError struct {
	GroupID         int64  // ID of the group where the mix was detected
	GroupName       string // display name of that group (or a "Group <id>" fallback)
	CurrentPlatform string // channel label of the account being bound
	OtherPlatform   string // channel label of the conflicting account already in the group
}

// Error implements the error interface. The message starts with "mixed_channel_warning:"
// and explains the thinking-block signature risk of mixing channels.
func (e *MixedChannelError) Error() string {
	const format = "mixed_channel_warning: Group '%s' contains both %s and %s accounts. " +
		"Using mixed channels in the same context may cause thinking block signature validation issues, " +
		"which will fallback to non-thinking mode for historical messages."
	return fmt.Sprintf(format, e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}