package service

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
	"github.com/dgraph-io/ristretto"
	"golang.org/x/sync/singleflight"
)

// Sentinel errors returned by the API key service.
var (
	ErrAPIKeyNotFound     = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
	ErrGroupNotAllowed    = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
	ErrAPIKeyExists       = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
	ErrAPIKeyTooShort     = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
	ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
	ErrAPIKeyRateLimited  = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
	ErrInvalidIPPattern   = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")

	// English variant of the message, kept for reference:
	// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
	ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")

	// English variant of the message, kept for reference:
	// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
	ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
)

const (
	// apiKeyMaxErrorsPerHour is the create-attempt count at which
	// checkAPIKeyRateLimit starts rejecting custom-key creation.
	apiKeyMaxErrorsPerHour = 20
)

// APIKeyRepository is the persistence interface used by APIKeyService.
type APIKeyRepository interface {
	Create(ctx context.Context, key *APIKey) error
	GetByID(ctx context.Context, id int64) (*APIKey, error)
	// GetKeyAndOwnerID fetches only the key string and the owner ID of an API
	// key; intended for lightweight paths such as deletion.
	GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
	GetByKey(ctx context.Context, key string) (*APIKey, error)
	// GetByKeyForAuth is the authentication-only lookup; it returns a minimal
	// set of fields.
	GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
	Update(ctx context.Context, key *APIKey) error
	Delete(ctx context.Context, id int64) error

	ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
	VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
	CountByUserID(ctx context.Context, userID int64) (int64, error)
	ExistsByKey(ctx context.Context, key string) (bool, error)
	ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
	SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
	ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
	CountByGroupID(ctx context.Context, groupID int64) (int64, error)
	ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
	ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)

	// Quota methods
	IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
}

// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
	GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
	IncrementCreateAttemptCount(ctx context.Context, userID int64) error
	DeleteCreateAttemptCount(ctx context.Context, userID int64) error

	IncrementDailyUsage(ctx context.Context, apiKey string) error
	SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error

	GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
	SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
	DeleteAuthCache(ctx context.Context, key string) error

	// Pub/Sub for L1 cache invalidation across instances
	PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
	SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}

// APIKeyAuthCacheInvalidator provides authentication-cache invalidation.
type APIKeyAuthCacheInvalidator interface {
	InvalidateAuthCacheByKey(ctx context.Context, key string)
	InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
	InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}

// CreateAPIKeyRequest is the payload for creating an API key.
type CreateAPIKeyRequest struct {
	Name        string   `json:"name"`
	GroupID     *int64   `json:"group_id"`
	CustomKey   *string  `json:"custom_key"`   // optional caller-supplied key
	IPWhitelist []string `json:"ip_whitelist"` // IP whitelist
	IPBlacklist []string `json:"ip_blacklist"` // IP blacklist

	// Quota fields
	Quota         float64 `json:"quota"`           // Quota limit in USD (0 = unlimited)
	ExpiresInDays *int    `json:"expires_in_days"` // Days until expiry (nil = never expires)
}

// UpdateAPIKeyRequest is the payload for updating an API key.
type UpdateAPIKeyRequest struct {
	Name        *string  `json:"name"`
	GroupID     *int64   `json:"group_id"`
	Status      *string  `json:"status"`
	IPWhitelist []string `json:"ip_whitelist"` // IP whitelist (an empty list clears it)
	IPBlacklist []string `json:"ip_blacklist"` // IP blacklist (an empty list clears it)

	// Quota fields
	Quota           *float64   `json:"quota"`       // Quota limit in USD (nil = no change, 0 = unlimited)
	ExpiresAt       *time.Time `json:"expires_at"`  // Expiration time (nil = no change)
	ClearExpiration bool       `json:"-"`           // Clear expiration (internal use)
	ResetQuota      *bool      `json:"reset_quota"` // Reset quota_used to 0
}

// APIKeyService implements API key management and authentication lookups.
type APIKeyService struct {
	apiKeyRepo        APIKeyRepository
	userRepo          UserRepository
	groupRepo         GroupRepository
	userSubRepo       UserSubscriptionRepository
	userGroupRateRepo UserGroupRateRepository
	cache             APIKeyCache
	cfg               *config.Config
	authCacheL1       *ristretto.Cache
	authCfg           apiKeyAuthCacheConfig
	authGroup         singleflight.Group
}

// NewAPIKeyService builds an APIKeyService from its dependencies and
// initialises the authentication cache from cfg.
func NewAPIKeyService(
	apiKeyRepo APIKeyRepository,
	userRepo UserRepository,
	groupRepo GroupRepository,
	userSubRepo UserSubscriptionRepository,
	userGroupRateRepo UserGroupRateRepository,
	cache APIKeyCache,
	cfg *config.Config,
) *APIKeyService {
	svc := &APIKeyService{
		apiKeyRepo:        apiKeyRepo,
		userRepo:          userRepo,
		groupRepo:         groupRepo,
		userSubRepo:       userSubRepo,
		userGroupRateRepo: userGroupRateRepo,
		cache:             cache,
		cfg:               cfg,
	}
	svc.initAuthCache(cfg)
	return svc
}

// GenerateKey returns a new random API key: the configured prefix (or "sk-"
// when none is configured) followed by 32 random bytes encoded as hex.
func (s *APIKeyService) GenerateKey() (string, error) {
	bytes := make([]byte, 32)
	if _, err := rand.Read(bytes); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}

	prefix := s.cfg.Default.APIKeyPrefix
	if prefix == "" {
		prefix = "sk-"
	}

	return prefix + hex.EncodeToString(bytes), nil
}

// ValidateCustomKey checks the format of a caller-supplied API key.
// It returns ErrAPIKeyTooShort when the key has fewer than 16 bytes and
// ErrAPIKeyInvalidChars when it contains anything other than ASCII letters,
// digits, underscores, or hyphens.
func (s *APIKeyService) ValidateCustomKey(key string) error {
	if len(key) < 16 {
		return ErrAPIKeyTooShort
	}

	for _, c := range key {
		if (c >= 'a' && c <= 'z') ||
			(c >= 'A' && c <= 'Z') ||
			(c >= '0' && c <= '9') ||
			c == '_' || c == '-' {
			continue
		}
		return ErrAPIKeyInvalidChars
	}

	return nil
}

// checkAPIKeyRateLimit reports ErrAPIKeyRateLimited once the user's recorded
// create-attempt count reaches apiKeyMaxErrorsPerHour. It is a no-op when no
// cache is configured.
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
	if s.cache == nil {
		return nil
	}

	count, err := s.cache.GetCreateAttemptCount(ctx, userID)
	if err != nil {
		// A cache failure must not block the user's operation.
		return nil
	}

	if count >= apiKeyMaxErrorsPerHour {
		return ErrAPIKeyRateLimited
	}

	return nil
}

// incrementAPIKeyErrorCount bumps the user's create-attempt counter.
// It is a no-op when no cache is configured.
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
	if s.cache == nil {
		return
	}

	// Best effort: a failed increment must not fail the request.
	_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}

// canUserBindGroup reports whether user may bind an API key to group.
//   - Subscription-type groups: the user needs an active subscription.
//   - Standard groups: decided by User.CanBindGroup.
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
	if group.IsSubscriptionType() {
		// Any lookup error (including "not found") is treated as "no active subscription".
		_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
		return err == nil
	}
	return user.CanBindGroup(group.ID, group.IsExclusive)
}

// Create creates an API key for userID.
//
// The IP lists and the optional group binding are validated first. When
// req.CustomKey is non-empty it is rate-limited, format-checked and checked
// for uniqueness; otherwise a random key is generated. On success the auth
// cache entry for the new key is invalidated and the stored key is returned.
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
	// The user must exist.
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Validate the IP whitelist patterns.
	if len(req.IPWhitelist) > 0 {
		if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
			return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
		}
	}
	// Validate the IP blacklist patterns.
	if len(req.IPBlacklist) > 0 {
		if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
			return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
		}
	}

	// Check group permission when a group was requested.
	if req.GroupID != nil {
		group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
		if err != nil {
			return nil, fmt.Errorf("get group: %w", err)
		}
		if !s.canUserBindGroup(ctx, user, group) {
			return nil, ErrGroupNotAllowed
		}
	}

	var key string
	if req.CustomKey != nil && *req.CustomKey != "" {
		// Rate limiting applies to custom keys only.
		if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
			return nil, err
		}

		if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
			return nil, err
		}

		exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
		if err != nil {
			return nil, fmt.Errorf("check key exists: %w", err)
		}
		if exists {
			// A collision counts as a failed attempt.
			s.incrementAPIKeyErrorCount(ctx, userID)
			return nil, ErrAPIKeyExists
		}

		key = *req.CustomKey
	} else {
		key, err = s.GenerateKey()
		if err != nil {
			return nil, fmt.Errorf("generate key: %w", err)
		}
	}

	apiKey := &APIKey{
		UserID:      userID,
		Key:         key,
		Name:        req.Name,
		GroupID:     req.GroupID,
		Status:      StatusActive,
		IPWhitelist: req.IPWhitelist,
		IPBlacklist: req.IPBlacklist,
		Quota:       req.Quota,
		QuotaUsed:   0,
	}

	// Set expiration time if specified
	if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
		expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
		apiKey.ExpiresAt = &expiresAt
	}

	if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
		return nil, fmt.Errorf("create api key: %w", err)
	}

	s.InvalidateAuthCacheByKey(ctx, apiKey.Key)

	return apiKey, nil
}

// List returns one page of the user's API keys together with the pagination
// result reported by the repository.
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	// The local is deliberately not called "pagination" so that it does not
	// shadow the imported package.
	keys, pageResult, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list api keys: %w", err)
	}
	return keys, pageResult, nil
}

// VerifyOwnership returns the IDs the repository reports for userID out of
// apiKeyIDs. An empty input yields an empty, non-nil slice without touching
// the repository.
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
	if len(apiKeyIDs) == 0 {
		return []int64{}, nil
	}

	validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
	if err != nil {
		return nil, fmt.Errorf("verify api key ownership: %w", err)
	}
	return validIDs, nil
}

// GetByID fetches an API key by its ID.
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("get api key: %w", err)
	}
	return apiKey, nil
}

// GetByKey resolves an API key by its key string (used for authentication).
//
// Lookup order: an existing auth cache entry, then an entry produced by
// loadAuthCacheEntry (de-duplicated through singleflight when enabled), and
// finally a direct GetByKeyForAuth query when the entry could not be applied.
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
	cacheKey := s.authCacheKey(key)

	if cached, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
		if apiKey, used, err := s.applyAuthCacheEntry(key, cached); used {
			if err != nil {
				return nil, fmt.Errorf("get api key: %w", err)
			}
			return apiKey, nil
		}
	}

	var (
		entry *APIKeyAuthCacheEntry
		err   error
	)
	if s.authCfg.singleflight {
		var value any
		value, err, _ = s.authGroup.Do(cacheKey, func() (any, error) {
			return s.loadAuthCacheEntry(ctx, key, cacheKey)
		})
		if err == nil {
			// A failed assertion leaves entry nil, exactly like a nil result.
			entry, _ = value.(*APIKeyAuthCacheEntry)
		}
	} else {
		entry, err = s.loadAuthCacheEntry(ctx, key, cacheKey)
	}
	if err != nil {
		return nil, err
	}

	if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
		if err != nil {
			return nil, fmt.Errorf("get api key: %w", err)
		}
		return apiKey, nil
	}

	// The entry was not usable: fall back to the repository.
	apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
	if err != nil {
		return nil, fmt.Errorf("get api key: %w", err)
	}
	apiKey.Key = key
	return apiKey, nil
}

// Update applies req to the API key identified by id, which must belong to
// userID (otherwise ErrInsufficientPerms is returned).
//
// Pointer fields that are nil are left unchanged. IPWhitelist and IPBlacklist
// are always overwritten with the request values, so an empty (or absent)
// list clears the restriction. Raising or removing the quota, resetting the
// used quota, and clearing or extending the expiry reactivate a key whose
// status was set to the matching exhausted/expired state. On success the auth
// cache entry for the key is invalidated.
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("get api key: %w", err)
	}

	// Ownership check.
	if apiKey.UserID != userID {
		return nil, ErrInsufficientPerms
	}

	// Validate the IP whitelist patterns.
	if len(req.IPWhitelist) > 0 {
		if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
			return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
		}
	}
	// Validate the IP blacklist patterns.
	if len(req.IPBlacklist) > 0 {
		if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
			return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
		}
	}

	if req.Name != nil {
		apiKey.Name = *req.Name
	}

	if req.GroupID != nil {
		// Re-check that the user may bind the requested group.
		user, err := s.userRepo.GetByID(ctx, userID)
		if err != nil {
			return nil, fmt.Errorf("get user: %w", err)
		}

		group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
		if err != nil {
			return nil, fmt.Errorf("get group: %w", err)
		}

		if !s.canUserBindGroup(ctx, user, group) {
			return nil, ErrGroupNotAllowed
		}

		apiKey.GroupID = req.GroupID
	}

	if req.Status != nil {
		apiKey.Status = *req.Status
		// Whenever a status is supplied, clear the owner's create-attempt
		// counter in the cache (best effort).
		if s.cache != nil {
			_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
		}
	}

	// Update quota fields
	if req.Quota != nil {
		apiKey.Quota = *req.Quota
		// Reactivate a quota-exhausted key when the new quota is unlimited
		// (<= 0, matching the "Quota > 0" check in UpdateQuotaUsed) or exceeds
		// what has already been used.
		quotaAvailable := *req.Quota <= 0 || *req.Quota > apiKey.QuotaUsed
		if apiKey.Status == StatusAPIKeyQuotaExhausted && quotaAvailable {
			apiKey.Status = StatusActive
		}
	}
	if req.ResetQuota != nil && *req.ResetQuota {
		apiKey.QuotaUsed = 0
		// If resetting quota and status was quota_exhausted, reactivate
		if apiKey.Status == StatusAPIKeyQuotaExhausted {
			apiKey.Status = StatusActive
		}
	}

	if req.ClearExpiration {
		apiKey.ExpiresAt = nil
		// If clearing expiry and status was expired, reactivate
		if apiKey.Status == StatusAPIKeyExpired {
			apiKey.Status = StatusActive
		}
	} else if req.ExpiresAt != nil {
		apiKey.ExpiresAt = req.ExpiresAt
		// If extending expiry and status was expired, reactivate
		if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
			apiKey.Status = StatusActive
		}
	}

	// IP restrictions are always replaced; an empty list clears them.
	apiKey.IPWhitelist = req.IPWhitelist
	apiKey.IPBlacklist = req.IPBlacklist

	if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
		return nil, fmt.Errorf("update api key: %w", err)
	}

	s.InvalidateAuthCacheByKey(ctx, apiKey.Key)

	return apiKey, nil
}

// Delete removes the API key identified by id, which must belong to userID
// (otherwise ErrInsufficientPerms is returned).
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
	key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
	if err != nil {
		return fmt.Errorf("get api key: %w", err)
	}

	// Only the owner may delete the key.
	if ownerID != userID {
		return ErrInsufficientPerms
	}

	// Clear the create-attempt counter (best effort). userID equals the
	// owner ID at this point.
	if s.cache != nil {
		_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
	}

	if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete api key: %w", err)
	}

	// Invalidate only after the row is gone: invalidating first would let a
	// concurrent authentication re-populate the cache from the still-present
	// row and keep the deleted key usable until the entry expires.
	s.InvalidateAuthCacheByKey(ctx, key)

	return nil
}

// ValidateKey resolves key and checks that both the API key and its owner
// are active (used by the authentication middleware). It returns the key and
// the owning user on success.
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
	apiKey, err := s.GetByKey(ctx, key)
	if err != nil {
		return nil, nil, err
	}

	if !apiKey.IsActive() {
		return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
	}

	user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
	if err != nil {
		return nil, nil, fmt.Errorf("get user: %w", err)
	}

	if !user.IsActive() {
		return nil, nil, ErrUserNotActive
	}

	return apiKey, user, nil
}

// IncrementUsage bumps the per-day usage counter of an API key in the cache
// (optional statistics). It is a no-op when no cache is configured.
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
	if s.cache == nil {
		return nil
	}

	// One counter per key and calendar day.
	cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
	if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
		return fmt.Errorf("increment usage: %w", err)
	}
	// Expire the counter after 24 hours (best effort).
	_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
	return nil
}

// GetAvailableGroups returns the active groups that the user may bind:
//   - standard groups allowed by User.CanBindGroup;
//   - subscription-type groups for which the user has an active subscription.
//
// The returned slice is never nil.
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	allGroups, err := s.groupRepo.ListActive(ctx)
	if err != nil {
		return nil, fmt.Errorf("list active groups: %w", err)
	}

	activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("list active subscriptions: %w", err)
	}

	// Set of group IDs the user is subscribed to.
	subscribedGroupIDs := make(map[int64]bool, len(activeSubscriptions))
	for _, sub := range activeSubscriptions {
		subscribedGroupIDs[sub.GroupID] = true
	}

	// Index-based iteration avoids copying each Group and taking the address
	// of the loop variable.
	availableGroups := make([]Group, 0, len(allGroups))
	for i := range allGroups {
		group := &allGroups[i]
		if s.canUserBindGroupInternal(user, group, subscribedGroupIDs) {
			availableGroups = append(availableGroups, *group)
		}
	}

	return availableGroups, nil
}

// canUserBindGroupInternal is the variant of canUserBindGroup that uses
// pre-loaded subscription data instead of querying the repository.
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
	if group.IsSubscriptionType() {
		return subscribedGroupIDs[group.ID]
	}
	return user.CanBindGroup(group.ID, group.IsExclusive)
}

// SearchAPIKeys delegates a keyword search over the user's API keys to the
// repository.
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
	keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
	if err != nil {
		return nil, fmt.Errorf("search api keys: %w", err)
	}
	return keys, nil
}

// GetUserGroupRates returns the user's per-group rate multipliers as
// map[groupID]rateMultiplier. It returns (nil, nil) when no rate repository
// is configured.
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
	if s.userGroupRateRepo == nil {
		return nil, nil
	}
	rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user group rates: %w", err)
	}
	return rates, nil
}

// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not
// expired, quota not exhausted). It returns nil if valid, ErrAPIKeyExpired or
// ErrAPIKeyQuotaExhausted otherwise; expiry is checked first.
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
	if apiKey.IsExpired() {
		return ErrAPIKeyExpired
	}

	if apiKey.IsQuotaExhausted() {
		return ErrAPIKeyQuotaExhausted
	}

	return nil
}

// UpdateQuotaUsed adds cost to the key's quota_used after a request and marks
// the key as quota-exhausted once a positive quota has been reached.
//
// Non-positive costs are ignored. Only a failure of the increment itself is
// returned; failures while reading or updating the status are swallowed so
// that they do not fail the request.
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
	if cost <= 0 {
		return nil
	}

	// Use repository to atomically increment quota_used
	newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
	if err != nil {
		return fmt.Errorf("increment quota used: %w", err)
	}

	// Check if quota is now exhausted and update status if needed
	apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
	if err != nil {
		// Deliberately ignored: the usage was recorded, do not fail the request.
		return nil
	}

	// If quota is set and now exhausted, update status
	if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
		apiKey.Status = StatusAPIKeyQuotaExhausted
		if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
			// Deliberately ignored: do not fail the request.
			return nil
		}
		// Invalidate cache so next request sees the new status
		s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
	}

	return nil
}