First commit
This commit is contained in:
464
backend/internal/service/api_key_service.go
Normal file
464
backend/internal/service/api_key_service.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyRateLimitDuration = time.Hour
|
||||
)
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
groupRepo *repository.GroupRepository,
|
||||
userSubRepo *repository.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
|
||||
key := prefix + hex.EncodeToString(bytes)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrApiKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrApiKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
return err == nil // 有有效订阅则允许
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户是否可以绑定该分组
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证自定义Key格式
|
||||
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查Key是否已存在
|
||||
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
} else {
|
||||
// 生成随机API Key
|
||||
var err error
|
||||
key, err = s.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &model.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: model.StatusActive,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("create api key: %w", err)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.rdb != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.GroupID != nil {
|
||||
// 验证分组权限
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
|
||||
apiKey.GroupID = req.GroupID
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Delete 删除API Key
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrApiKeyNotFound
|
||||
}
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 检查API Key状态
|
||||
if !apiKey.IsActive() {
|
||||
return nil, nil, errors.New("api key is not active")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, ErrUserNotFound
|
||||
}
|
||||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
return apiKey, user, nil
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户有权限绑定的分组列表
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有活跃分组
|
||||
allGroups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
|
||||
// 获取用户的所有有效订阅
|
||||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
|
||||
// 构建订阅分组 ID 集合
|
||||
subscribedGroupIDs := make(map[int64]bool)
|
||||
for _, sub := range activeSubscriptions {
|
||||
subscribedGroupIDs[sub.GroupID] = true
|
||||
}
|
||||
|
||||
// 过滤出用户有权限的分组
|
||||
availableGroups := make([]model.Group, 0)
|
||||
for _, group := range allGroups {
|
||||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||||
availableGroups = append(availableGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return availableGroups, nil
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
Reference in New Issue
Block a user