Files
xinghuoapi/backend/internal/service/api_key_service.go
2025-12-18 13:50:39 +08:00

465 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ErrApiKeyNotFound is returned when no API key matches the requested id or key string.
var ErrApiKeyNotFound = errors.New("api key not found")

// ErrGroupNotAllowed is returned when the user may not bind the requested group.
var ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")

// ErrApiKeyExists is returned when a user-chosen key is already in use.
var ErrApiKeyExists = errors.New("api key already exists")

// ErrApiKeyTooShort is returned when a user-chosen key is shorter than 16 bytes.
var ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")

// ErrApiKeyInvalidChars is returned when a user-chosen key contains anything
// other than ASCII letters, digits, '_' or '-'.
var ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")

// ErrApiKeyRateLimited is returned when the user's failure counter for
// custom-key creation has reached apiKeyMaxErrorsPerHour.
var ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")

// Rate limiting of custom-key creation failures. Failures are counted per user
// in Redis under apiKeyRateLimitKeyPrefix followed by the user ID. Once the
// counter reaches apiKeyMaxErrorsPerHour, further custom-key requests are
// rejected until the counter key expires, apiKeyRateLimitDuration after it was
// last incremented.
const apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"

const apiKeyMaxErrorsPerHour = 20

const apiKeyRateLimitDuration = 60 * time.Minute
type (
	// CreateApiKeyRequest is the payload for creating an API key.
	// A nil GroupID creates the key without a group binding. A nil or empty
	// CustomKey makes the service generate a random key instead.
	CreateApiKeyRequest struct {
		Name      string  `json:"name"`
		GroupID   *int64  `json:"group_id"`
		CustomKey *string `json:"custom_key"` // optional user-chosen key
	}

	// UpdateApiKeyRequest is the payload for updating an API key.
	// Only the non-nil fields are applied to the stored key.
	UpdateApiKeyRequest struct {
		Name    *string `json:"name"`
		GroupID *int64  `json:"group_id"`
		Status  *string `json:"status"`
	}
)
// ApiKeyService manages API keys: creation (generated or user-supplied keys),
// listing, lookup, update, deletion, authentication checks and the rules for
// binding a key to a group.
type ApiKeyService struct {
	apiKeyRepo  *repository.ApiKeyRepository           // API key persistence
	userRepo    *repository.UserRepository             // user lookups (existence, status, group permissions)
	groupRepo   *repository.GroupRepository            // group lookups
	userSubRepo *repository.UserSubscriptionRepository // active-subscription lookups for subscription-type groups
	rdb         *redis.Client                          // optional; every method skips its Redis work when nil
	cfg         *config.Config                         // read by GenerateKey for the key prefix
}
// NewApiKeyService builds an ApiKeyService from its dependencies.
//
// rdb may be nil. The service's methods check for that and skip Redis-backed
// work (rate limiting, cache invalidation, usage counters).
func NewApiKeyService(
	apiKeyRepo *repository.ApiKeyRepository,
	userRepo *repository.UserRepository,
	groupRepo *repository.GroupRepository,
	userSubRepo *repository.UserSubscriptionRepository,
	rdb *redis.Client,
	cfg *config.Config,
) *ApiKeyService {
	svc := new(ApiKeyService)
	svc.apiKeyRepo = apiKeyRepo
	svc.userRepo = userRepo
	svc.groupRepo = groupRepo
	svc.userSubRepo = userSubRepo
	svc.rdb = rdb
	svc.cfg = cfg
	return svc
}
// GenerateKey returns a fresh random API key.
//
// The key is the configured prefix (cfg.Default.ApiKeyPrefix, or "sk-" when
// that is empty) followed by 64 lowercase hex characters. Those characters
// encode 32 bytes read from crypto/rand. An error is returned only if the
// random source fails.
func (s *ApiKeyService) GenerateKey() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}

	prefix := s.cfg.Default.ApiKeyPrefix
	if prefix == "" {
		prefix = "sk-"
	}
	return prefix + hex.EncodeToString(entropy[:]), nil
}
// ValidateCustomKey checks the format of a user-chosen API key.
//
// Rules:
//   - The key must be at least 16 bytes long; otherwise it returns ErrApiKeyTooShort.
//     Only ASCII is accepted, so for valid keys this equals 16 characters.
//   - Every character must be an ASCII letter, a digit, '_' or '-'; otherwise
//     it returns ErrApiKeyInvalidChars.
//
// It returns nil when the key is acceptable.
func (s *ApiKeyService) ValidateCustomKey(key string) error {
	if len(key) < 16 {
		return ErrApiKeyTooShort
	}
	for _, r := range key {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9', r == '_', r == '-':
			// allowed character
		default:
			return ErrApiKeyInvalidChars
		}
	}
	return nil
}
// checkApiKeyRateLimit reports whether userID may attempt another custom-key
// creation.
//
// It returns ErrApiKeyRateLimited once the user's failure counter in Redis has
// reached apiKeyMaxErrorsPerHour. A missing counter counts as zero. The check
// fails open: it returns nil when Redis is not configured or the lookup fails
// for any reason other than a missing key.
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
	if s.rdb == nil {
		return nil
	}

	failures, err := s.rdb.Get(ctx, fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)).Int()
	switch {
	case err != nil && !errors.Is(err, redis.Nil):
		// Redis trouble must not block the user.
		return nil
	case failures >= apiKeyMaxErrorsPerHour:
		return ErrApiKeyRateLimited
	default:
		return nil
	}
}
// incrementApiKeyErrorCount records one failed custom-key creation for userID.
//
// It sends INCR and EXPIRE for the user's counter in a single pipeline. The
// expiry is reset to apiKeyRateLimitDuration on every call, so the counter
// disappears that long after the most recent failure. This is best-effort:
// Redis errors are ignored, and nothing happens when Redis is not configured.
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
	if s.rdb == nil {
		return
	}

	counterKey := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
	// Failing to record a failure must not affect the caller's request.
	_, _ = s.rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.Incr(ctx, counterKey)
		pipe.Expire(ctx, counterKey, apiKeyRateLimitDuration)
		return nil
	})
}
// canUserBindGroup reports whether user may bind an API key to group.
//
// Standard groups defer to user.CanBindGroup(group.ID, group.IsExclusive).
// Subscription-type groups require an active subscription from userSubRepo.
// Any lookup error, including infrastructure failures, denies the binding.
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
	if !group.IsSubscriptionType() {
		return user.CanBindGroup(group.ID, group.IsExclusive)
	}
	_, lookupErr := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
	return lookupErr == nil
}
// Create issues a new API key owned by userID.
//
// Group handling: when req.GroupID is set, the group must exist and the user
// must be allowed to bind it (see canUserBindGroup).
//
// Key selection: when req.CustomKey is non-empty it is used verbatim after
// three checks, in order:
//   - the per-user rate limit;
//   - the format rules of ValidateCustomKey;
//   - uniqueness.
//
// Only a uniqueness collision is counted toward the rate limit. Otherwise a
// random key is generated with GenerateKey. The new key is stored with status
// model.StatusActive.
//
// Errors: ErrUserNotFound, "group not found", ErrGroupNotAllowed,
// ErrApiKeyRateLimited, ErrApiKeyTooShort, ErrApiKeyInvalidChars,
// ErrApiKeyExists, or a wrapped repository/generation error.
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
	owner, err := s.userRepo.GetByID(ctx, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrUserNotFound
	case err != nil:
		return nil, fmt.Errorf("get user: %w", err)
	}

	if req.GroupID != nil {
		group, groupErr := s.groupRepo.GetByID(ctx, *req.GroupID)
		switch {
		case errors.Is(groupErr, gorm.ErrRecordNotFound):
			return nil, errors.New("group not found")
		case groupErr != nil:
			return nil, fmt.Errorf("get group: %w", groupErr)
		case !s.canUserBindGroup(ctx, owner, group):
			return nil, ErrGroupNotAllowed
		}
	}

	var secret string
	if custom := req.CustomKey; custom != nil && *custom != "" {
		// Only user-chosen keys go through the rate limiter.
		// NOTE(review): presumably this limits probing for existing keys via ErrApiKeyExists.
		if err = s.checkApiKeyRateLimit(ctx, userID); err != nil {
			return nil, err
		}
		if err = s.ValidateCustomKey(*custom); err != nil {
			return nil, err
		}
		var taken bool
		if taken, err = s.apiKeyRepo.ExistsByKey(ctx, *custom); err != nil {
			return nil, fmt.Errorf("check key exists: %w", err)
		}
		if taken {
			s.incrementApiKeyErrorCount(ctx, userID)
			return nil, ErrApiKeyExists
		}
		secret = *custom
	} else if secret, err = s.GenerateKey(); err != nil {
		return nil, fmt.Errorf("generate key: %w", err)
	}

	record := &model.ApiKey{
		UserID:  userID,
		Key:     secret,
		Name:    req.Name,
		GroupID: req.GroupID,
		Status:  model.StatusActive,
	}
	if err = s.apiKeyRepo.Create(ctx, record); err != nil {
		return nil, fmt.Errorf("create api key: %w", err)
	}
	return record, nil
}
// List returns one page of the API keys owned by userID, together with the
// pagination metadata from the repository. Repository errors are wrapped.
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
	keys, page, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
	if err == nil {
		return keys, page, nil
	}
	return nil, nil, fmt.Errorf("list api keys: %w", err)
}
// GetByID loads the API key with the given id.
// It returns ErrApiKeyNotFound when no such record exists, and a wrapped error
// for any other repository failure.
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
	switch {
	case err == nil:
		return apiKey, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrApiKeyNotFound
	default:
		return nil, fmt.Errorf("get api key: %w", err)
	}
}
// GetByKey looks up an API key by its secret key string. ValidateKey uses it
// for authentication.
//
// It returns ErrApiKeyNotFound when no record matches, and a wrapped error for
// any other repository failure.
//
// TODO: no Redis caching is implemented yet; every call queries the
// repository. Update and Delete already delete "apikey:<key>" entries, so a
// future cache should use that key format.
// NOTE(review): embedding the raw secret in a Redis key name exposes it to
// anyone who can list keys; consider hashing it if caching is added.
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
	apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrApiKeyNotFound
		}
		return nil, fmt.Errorf("get api key: %w", err)
	}
	return apiKey, nil
}
// Update applies the non-nil fields of req to the API key identified by id,
// after checking that the key belongs to userID.
//
// Errors:
//   - ErrApiKeyNotFound if no key with that id exists.
//   - ErrInsufficientPerms if the key belongs to another user.
//   - ErrUserNotFound if userID no longer exists (only checked when GroupID changes).
//   - "group not found" / ErrGroupNotAllowed if the requested group is missing
//     or the user may not bind it.
//   - A wrapped repository error otherwise.
//
// After a successful write, any cached "apikey:<key>" entry is deleted.
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrApiKeyNotFound
		}
		return nil, fmt.Errorf("get api key: %w", err)
	}
	// Only the owner may modify a key.
	if apiKey.UserID != userID {
		return nil, ErrInsufficientPerms
	}

	if req.Name != nil {
		apiKey.Name = *req.Name
	}
	if req.GroupID != nil {
		// Re-check the binding rules for the new group, as Create does.
		owner, userErr := s.userRepo.GetByID(ctx, userID)
		if userErr != nil {
			if errors.Is(userErr, gorm.ErrRecordNotFound) {
				return nil, ErrUserNotFound
			}
			return nil, fmt.Errorf("get user: %w", userErr)
		}
		group, groupErr := s.groupRepo.GetByID(ctx, *req.GroupID)
		if groupErr != nil {
			if errors.Is(groupErr, gorm.ErrRecordNotFound) {
				return nil, errors.New("group not found")
			}
			return nil, fmt.Errorf("get group: %w", groupErr)
		}
		if !s.canUserBindGroup(ctx, owner, group) {
			return nil, ErrGroupNotAllowed
		}
		apiKey.GroupID = req.GroupID
	}
	if req.Status != nil {
		// NOTE(review): stored verbatim; this method does not check it against
		// the known status values.
		apiKey.Status = *req.Status
	}

	if err = s.apiKeyRepo.Update(ctx, apiKey); err != nil {
		return nil, fmt.Errorf("update api key: %w", err)
	}

	// Invalidate only after the new state is persisted. Deleting first would let
	// a concurrent reader re-cache the old record before the write lands. Name,
	// group and status all change the stored record, so invalidate on every
	// successful update, not only on status changes.
	if s.rdb != nil {
		// Best-effort: a failed cache delete must not fail an update that is
		// already persisted.
		_ = s.rdb.Del(ctx, fmt.Sprintf("apikey:%s", apiKey.Key)).Err()
	}
	return apiKey, nil
}
// Delete removes the API key identified by id, after checking that it belongs
// to userID.
//
// It returns ErrApiKeyNotFound if the key does not exist, ErrInsufficientPerms
// if it belongs to another user, and a wrapped repository error otherwise.
// After the row is deleted, any cached "apikey:<key>" entry is deleted as well.
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrApiKeyNotFound
		}
		return fmt.Errorf("get api key: %w", err)
	}
	// Only the owner may delete a key.
	if apiKey.UserID != userID {
		return ErrInsufficientPerms
	}

	if err = s.apiKeyRepo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete api key: %w", err)
	}

	// Drop the cache entry only after the row is gone. Dropping it first leaves
	// a window in which a concurrent lookup could re-cache the revoked key.
	if s.rdb != nil {
		// Best-effort: the key is already deleted from the database.
		_ = s.rdb.Del(ctx, fmt.Sprintf("apikey:%s", apiKey.Key)).Err()
	}
	return nil
}
// ValidateKey authenticates a presented key string. It is meant for the
// authentication middleware.
//
// It returns the key record and its owner when both are active. Errors:
//   - any error from GetByKey (e.g. ErrApiKeyNotFound);
//   - "api key is not active";
//   - ErrUserNotFound;
//   - ErrUserNotActive;
//   - a wrapped user-lookup error.
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
	apiKey, err := s.GetByKey(ctx, key)
	if err != nil {
		return nil, nil, err
	}
	if !apiKey.IsActive() {
		return nil, nil, errors.New("api key is not active")
	}

	owner, err := s.userRepo.GetByID(ctx, apiKey.UserID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil, ErrUserNotFound
	case err != nil:
		return nil, nil, fmt.Errorf("get user: %w", err)
	case !owner.IsActive():
		return nil, nil, ErrUserNotActive
	}
	return apiKey, owner, nil
}
// IncrementUsage increments the per-day usage counter of API key keyID in
// Redis. It is optional statistics and a no-op returning nil when Redis is not
// configured.
//
// The counter lives under "apikey:usage:<keyID>:<YYYY-MM-DD>", with the date
// taken from timezone.Now(). Its TTL is (re)set to 24 hours on every call.
// INCR and EXPIRE are queued in one MULTI/EXEC transaction, so a dropped
// connection cannot apply the increment without the expiry. A failure of
// either command is returned instead of being silently ignored; otherwise the
// counter could be left without a TTL and never expire.
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
	if s.rdb == nil {
		return nil
	}

	cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
	pipe := s.rdb.TxPipeline()
	pipe.Incr(ctx, cacheKey)
	pipe.Expire(ctx, cacheKey, 24*time.Hour)
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("increment usage: %w", err)
	}
	return nil
}
// GetAvailableGroups lists the active groups that userID may bind an API key to.
//
// Standard groups are filtered by user.CanBindGroup, which the original comment
// describes as "public (non-exclusive) or explicitly allowed". Subscription-type
// groups are included only when the user has an active subscription to them.
// A gorm.ErrRecordNotFound from the subscription lookup is treated as
// "no subscriptions". The result is never nil: it is an empty slice when
// nothing matches.
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrUserNotFound
	case err != nil:
		return nil, fmt.Errorf("get user: %w", err)
	}

	candidates, err := s.groupRepo.ListActive(ctx)
	if err != nil {
		return nil, fmt.Errorf("list active groups: %w", err)
	}

	subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, fmt.Errorf("list active subscriptions: %w", err)
	}

	// Load subscriptions once, so the per-group check needs no further queries.
	subscribed := make(map[int64]bool, len(subs))
	for i := range subs {
		subscribed[subs[i].GroupID] = true
	}

	result := []model.Group{}
	for i := range candidates {
		if s.canUserBindGroupInternal(user, &candidates[i], subscribed) {
			result = append(result, candidates[i])
		}
	}
	return result, nil
}
// canUserBindGroupInternal is the query-free variant of canUserBindGroup. It
// uses a preloaded set of subscribed group IDs instead of querying Redis or
// the database per group.
//
// Standard groups defer to user.CanBindGroup(group.ID, group.IsExclusive).
// Subscription-type groups are allowed only when subscribedGroupIDs[group.ID]
// is true.
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
	if !group.IsSubscriptionType() {
		return user.CanBindGroup(group.ID, group.IsExclusive)
	}
	return subscribedGroupIDs[group.ID]
}