feat: 品牌重命名 Sub2API -> TianShuAPI
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled

- 前端: 所有界面显示、i18n 文本、组件中的品牌名称
- 后端: 服务层、设置默认值、邮件模板、安装向导
- 数据库: 迁移脚本注释
- 保持功能完全一致,仅更改品牌名称

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
huangzhenpc
2026-01-04 17:50:29 +08:00
parent e27c1acf79
commit d274c8cb14
417 changed files with 112280 additions and 112280 deletions

View File

@@ -1,478 +1,478 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
}
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
key := prefix + hex.EncodeToString(bytes)
return key, nil
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil {
// Redis 出错时不阻止用户操作
return nil
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
return err == nil // 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
// 检查用户是否可以绑定该分组
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
}
var key string
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
// 验证自定义Key格式
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
return nil, err
}
// 检查Key是否已存在
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
} else {
// 生成随机API Key
var err error
key, err = s.GenerateKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
}
// 创建API Key记录
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("verify api key ownership: %w", err)
}
return validIDs, nil
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
return apiKey, nil
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return nil, ErrInsufficientPerms
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
}
if req.GroupID != nil {
// 验证分组权限
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
apiKey.GroupID = req.GroupID
}
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
// 验证当前用户是否为该 API Key 的所有者
if ownerID != userID {
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
return nil
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
return nil, nil, err
}
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
return nil, nil, fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
return apiKey, user, nil
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
}
return nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 获取所有活跃分组
allGroups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs := make(map[int64]bool)
for _, sub := range activeSubscriptions {
subscribedGroupIDs[sub.GroupID] = true
}
// 过滤出用户有权限的分组
availableGroups := make([]Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
}
}
return availableGroups, nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
}
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
key := prefix + hex.EncodeToString(bytes)
return key, nil
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil {
// Redis 出错时不阻止用户操作
return nil
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
return err == nil // 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
// 检查用户是否可以绑定该分组
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
}
var key string
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
// 验证自定义Key格式
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
return nil, err
}
// 检查Key是否已存在
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
} else {
// 生成随机API Key
var err error
key, err = s.GenerateKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
}
// 创建API Key记录
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("verify api key ownership: %w", err)
}
return validIDs, nil
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
return apiKey, nil
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return nil, ErrInsufficientPerms
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
}
if req.GroupID != nil {
// 验证分组权限
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
apiKey.GroupID = req.GroupID
}
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
// 验证当前用户是否为该 API Key 的所有者
if ownerID != userID {
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
return nil
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
return nil, nil, err
}
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
return nil, nil, fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
return apiKey, user, nil
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
}
return nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 获取所有活跃分组
allGroups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs := make(map[int64]bool)
for _, sub := range activeSubscriptions {
subscribedGroupIDs[sub.GroupID] = true
}
// 过滤出用户有权限的分组
availableGroups := make([]Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
}
}
return availableGroups, nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}