Files
yinghuoapi/backend/internal/repository/api_key_repo.go
yangjianbo 9d0a4f3d68 perf(认证): 引入 API Key 认证缓存与轻量删除查询
增加 L1/L2 缓存、负缓存与单飞回源
使用 key+owner 轻量查询替代全量加载并清理旧接口
补充缓存失效与余额更新测试,修复随机抖动 lint

测试: make test
2026-01-10 22:23:51 +08:00

436 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// apiKeyRepository is the concrete type that NewAPIKeyRepository returns as a
// service.APIKeyRepository; its methods issue their queries through the ent client.
type apiKeyRepository struct {
// client is the ent client used by the repository methods.
client *dbent.Client
}
// NewAPIKeyRepository wraps the given ent client in an apiKeyRepository and
// returns it as a service.APIKeyRepository.
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
	repo := new(apiKeyRepository)
	repo.client = client
	return repo
}
// activeQuery starts an API key query that carries the DeletedAtIsNil
// predicate, so soft-deleted keys are filtered out by default and cannot be
// returned after deletion.
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
	notDeleted := apikey.DeletedAtIsNil()
	return r.client.APIKey.Query().Where(notDeleted)
}
// Create inserts a new API key built from key's UserID, Key, Name, Status and
// optional GroupID. The IP whitelist and blacklist are written only when they
// are non-empty. On success the generated ID and the created/updated
// timestamps are copied back onto key.
//
// The save error (nil on success) is always handed to translatePersistenceError
// with service.ErrAPIKeyExists as its third argument (presumably the
// duplicate-key mapping; confirm against the helper).
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
	create := r.client.APIKey.Create().
		SetUserID(key.UserID).
		SetKey(key.Key).
		SetName(key.Name).
		SetStatus(key.Status).
		SetNillableGroupID(key.GroupID)

	if allow := key.IPWhitelist; len(allow) > 0 {
		create.SetIPWhitelist(allow)
	}
	if deny := key.IPBlacklist; len(deny) > 0 {
		create.SetIPBlacklist(deny)
	}

	saved, err := create.Save(ctx)
	if err == nil {
		key.ID, key.CreatedAt, key.UpdatedAt = saved.ID, saved.CreatedAt, saved.UpdatedAt
	}
	return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
}
// GetByID loads the non-soft-deleted API key with the given ID together with
// its User and Group edges. It returns service.ErrAPIKeyNotFound when ent
// reports a not-found error, and any other query error unchanged.
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
	query := r.activeQuery().Where(apikey.IDEQ(id)).WithUser().WithGroup()

	entity, err := query.Only(ctx)
	switch {
	case err == nil:
		return apiKeyEntityToService(entity), nil
	case dbent.IsNotFound(err):
		return nil, service.ErrAPIKeyNotFound
	default:
		return nil, err
	}
}
// GetKeyAndOwnerID returns the key string and the owning user ID of the
// non-soft-deleted API key with the given ID.
//
// It is intended as a lighter alternative to GetByID:
//   - Select() restricts the query to the key and user_id fields;
//   - no User or Group edges are loaded;
//   - it fits callers such as deletion that only need the key and the owner ID.
//
// It returns service.ErrAPIKeyNotFound when ent reports a not-found error.
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
	row, err := r.activeQuery().
		Where(apikey.IDEQ(id)).
		Select(apikey.FieldKey, apikey.FieldUserID).
		Only(ctx)
	if err == nil {
		return row.Key, row.UserID, nil
	}
	if dbent.IsNotFound(err) {
		err = service.ErrAPIKeyNotFound
	}
	return "", 0, err
}
// GetByKey loads the non-soft-deleted API key whose key string equals key,
// together with its User and Group edges. It returns
// service.ErrAPIKeyNotFound when ent reports a not-found error.
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
	byKey := r.activeQuery().Where(apikey.KeyEQ(key))

	found, err := byKey.WithUser().WithGroup().Only(ctx)
	if err == nil {
		return apiKeyEntityToService(found), nil
	}
	if dbent.IsNotFound(err) {
		return nil, service.ErrAPIKeyNotFound
	}
	return nil, err
}
// GetByKeyForAuth loads the non-soft-deleted API key whose key string equals
// key, selecting only the fields listed below for the key itself and for its
// User and Group edges.
//
// Fields outside these lists (for example the key string and name) are not
// selected, so they presumably come back as zero values on the returned
// object; verify against callers before relying on them.
//
// It returns service.ErrAPIKeyNotFound when ent reports a not-found error.
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
	keyFields := []string{
		apikey.FieldID,
		apikey.FieldUserID,
		apikey.FieldGroupID,
		apikey.FieldStatus,
		apikey.FieldIPWhitelist,
		apikey.FieldIPBlacklist,
	}
	userFields := []string{
		user.FieldID,
		user.FieldStatus,
		user.FieldRole,
		user.FieldBalance,
		user.FieldConcurrency,
	}
	groupFields := []string{
		group.FieldID,
		group.FieldName,
		group.FieldPlatform,
		group.FieldStatus,
		group.FieldSubscriptionType,
		group.FieldRateMultiplier,
		group.FieldDailyLimitUsd,
		group.FieldWeeklyLimitUsd,
		group.FieldMonthlyLimitUsd,
		group.FieldImagePrice1k,
		group.FieldImagePrice2k,
		group.FieldImagePrice4k,
		group.FieldClaudeCodeOnly,
		group.FieldFallbackGroupID,
	}

	entity, err := r.activeQuery().
		Where(apikey.KeyEQ(key)).
		Select(keyFields...).
		WithUser(func(uq *dbent.UserQuery) {
			uq.Select(userFields...)
		}).
		WithGroup(func(gq *dbent.GroupQuery) {
			gq.Select(groupFields...)
		}).
		Only(ctx)
	switch {
	case err == nil:
		return apiKeyEntityToService(entity), nil
	case dbent.IsNotFound(err):
		return nil, service.ErrAPIKeyNotFound
	default:
		return nil, err
	}
}
// Update writes key's Name, Status, GroupID and IP whitelist/blacklist to the
// row with key.ID. A nil GroupID or an empty IP list clears the corresponding
// column. It returns service.ErrAPIKeyNotFound when no row was affected.
//
// The soft-delete check and the update are combined into one statement
// (Update().Where() with DeletedAtIsNil) to avoid a race: the previous
// implementation checked existence and then called UpdateOneID, so a soft
// delete landing between the two steps could update a deleted record.
// updated_at is set explicitly so it does not have to be read back with a
// second query.
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
	stamp := time.Now()
	upd := r.client.APIKey.Update().
		Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
		SetName(key.Name).
		SetStatus(key.Status).
		SetUpdatedAt(stamp)

	if gid := key.GroupID; gid == nil {
		upd.ClearGroupID()
	} else {
		upd.SetGroupID(*gid)
	}

	// IP restriction fields: an empty list clears the column.
	if len(key.IPWhitelist) == 0 {
		upd.ClearIPWhitelist()
	} else {
		upd.SetIPWhitelist(key.IPWhitelist)
	}
	if len(key.IPBlacklist) == 0 {
		upd.ClearIPBlacklist()
	} else {
		upd.SetIPBlacklist(key.IPBlacklist)
	}

	switch rows, err := upd.Save(ctx); {
	case err != nil:
		return err
	case rows == 0:
		// Zero affected rows: the record does not exist or is already soft-deleted.
		return service.ErrAPIKeyNotFound
	}

	// Backfill with the timestamp that was written, so a concurrent delete
	// cannot make a follow-up read fail.
	key.UpdatedAt = stamp
	return nil
}
// Delete soft-deletes the API key with the given ID by setting deleted_at
// explicitly instead of relying on hook behaviour.
//
// When the update affects no row, the ID is looked up again with
// mixins.SkipSoftDelete(ctx) (presumably lifting the soft-delete filter;
// confirm against the mixin). If a row exists, Delete returns nil, so deleting
// an already-deleted key is not an error. Otherwise it returns
// service.ErrAPIKeyNotFound.
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
	rows, err := r.client.APIKey.Update().
		Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
		SetDeletedAt(time.Now()).
		Save(ctx)
	switch {
	case err != nil && dbent.IsNotFound(err):
		return service.ErrAPIKeyNotFound
	case err != nil:
		return err
	case rows != 0:
		return nil
	}

	// No active row matched: tell "already deleted" apart from "never existed".
	found, err := r.client.APIKey.Query().
		Where(apikey.IDEQ(id)).
		Exist(mixins.SkipSoftDelete(ctx))
	if err != nil {
		return err
	}
	if !found {
		return service.ErrAPIKeyNotFound
	}
	return nil
}
// ListByUserID returns one page of the user's non-soft-deleted API keys,
// ordered by ID descending with the Group edge loaded, plus a pagination
// result computed from the total match count and params.
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
	base := r.activeQuery().Where(apikey.UserIDEQ(userID))

	total, err := base.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	page := base.WithGroup().
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(apikey.FieldID))
	rows, err := page.All(ctx)
	if err != nil {
		return nil, nil, err
	}

	result := make([]service.APIKey, len(rows))
	for i, row := range rows {
		result[i] = *apiKeyEntityToService(row)
	}
	return result, paginationResultFromTotal(int64(total), params), nil
}
// VerifyOwnership returns the subset of apiKeyIDs that belong to userID and
// are not soft-deleted. An empty input yields an empty, non-nil slice without
// querying the database.
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
	if len(apiKeyIDs) == 0 {
		return []int64{}, nil
	}

	owned := r.client.APIKey.Query().Where(
		apikey.UserIDEQ(userID),
		apikey.IDIn(apiKeyIDs...),
		apikey.DeletedAtIsNil(),
	)
	matched, err := owned.IDs(ctx)
	if err != nil {
		return nil, err
	}
	return matched, nil
}
// CountByUserID returns the number of non-soft-deleted API keys owned by
// userID, together with any query error.
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	owned := r.activeQuery().Where(apikey.UserIDEQ(userID))
	n, err := owned.Count(ctx)
	return int64(n), err
}
// ExistsByKey reports whether a non-soft-deleted API key with the given key
// string exists. It uses Exist rather than Count because only the presence of
// a matching row matters, not how many there are. On a query error the
// boolean is false and the error is returned.
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
	return r.activeQuery().Where(apikey.KeyEQ(key)).Exist(ctx)
}
// ListByGroupID returns one page of the group's non-soft-deleted API keys,
// ordered by ID descending with the User edge loaded, plus a pagination
// result computed from the total match count and params.
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
	base := r.activeQuery().Where(apikey.GroupIDEQ(groupID))

	total, err := base.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	page := base.WithUser().
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(apikey.FieldID))
	rows, err := page.All(ctx)
	if err != nil {
		return nil, nil, err
	}

	result := make([]service.APIKey, len(rows))
	for i, row := range rows {
		result[i] = *apiKeyEntityToService(row)
	}
	return result, paginationResultFromTotal(int64(total), params), nil
}
// SearchAPIKeys returns up to limit non-soft-deleted API keys ordered by ID
// descending. The user filter applies only when userID > 0; the
// case-insensitive name filter applies only when keyword is non-empty.
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
	query := r.activeQuery()
	if userID > 0 {
		query = query.Where(apikey.UserIDEQ(userID))
	}
	if keyword != "" {
		query = query.Where(apikey.NameContainsFold(keyword))
	}

	rows, err := query.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
	if err != nil {
		return nil, err
	}

	result := make([]service.APIKey, len(rows))
	for i, row := range rows {
		result[i] = *apiKeyEntityToService(row)
	}
	return result, nil
}
// ClearGroupIDByGroupID clears group_id on every non-soft-deleted API key that
// currently belongs to groupID and returns the number of rows updated.
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
	inGroup := r.client.APIKey.Update().
		Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil())
	updated, err := inGroup.ClearGroupID().Save(ctx)
	return int64(updated), err
}
// CountByGroupID returns the number of non-soft-deleted API keys assigned to
// groupID, together with any query error.
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
	inGroup := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
	n, err := inGroup.Count(ctx)
	return int64(n), err
}
// ListKeysByUserID returns the key strings of all non-soft-deleted API keys
// owned by userID, selecting only the key field.
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
	keyColumn := r.activeQuery().
		Where(apikey.UserIDEQ(userID)).
		Select(apikey.FieldKey)

	values, err := keyColumn.Strings(ctx)
	if err != nil {
		return nil, err
	}
	return values, nil
}
// ListKeysByGroupID returns the key strings of all non-soft-deleted API keys
// assigned to groupID, selecting only the key field.
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
	keyColumn := r.activeQuery().
		Where(apikey.GroupIDEQ(groupID)).
		Select(apikey.FieldKey)

	values, err := keyColumn.Strings(ctx)
	if err != nil {
		return nil, err
	}
	return values, nil
}
// apiKeyEntityToService copies an ent APIKey into a new service.APIKey. The
// User and Group edges are converted only when they are loaded (non-nil). A
// nil input yields nil.
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
	if m == nil {
		return nil
	}

	var out service.APIKey
	out.ID = m.ID
	out.UserID = m.UserID
	out.Key = m.Key
	out.Name = m.Name
	out.Status = m.Status
	out.IPWhitelist = m.IPWhitelist
	out.IPBlacklist = m.IPBlacklist
	out.CreatedAt = m.CreatedAt
	out.UpdatedAt = m.UpdatedAt
	out.GroupID = m.GroupID

	if owner := m.Edges.User; owner != nil {
		out.User = userEntityToService(owner)
	}
	if grp := m.Edges.Group; grp != nil {
		out.Group = groupEntityToService(grp)
	}
	return &out
}
// userEntityToService copies the scalar fields of an ent User into a new
// service.User. A nil input yields nil.
func userEntityToService(u *dbent.User) *service.User {
	if u == nil {
		return nil
	}

	out := new(service.User)
	out.ID = u.ID
	out.Email = u.Email
	out.Username = u.Username
	out.Notes = u.Notes
	out.PasswordHash = u.PasswordHash
	out.Role = u.Role
	out.Balance = u.Balance
	out.Concurrency = u.Concurrency
	out.Status = u.Status
	out.CreatedAt = u.CreatedAt
	out.UpdatedAt = u.UpdatedAt
	return out
}
// groupEntityToService copies an ent Group into a new service.Group. A nil
// Description becomes the empty string and Hydrated is always set to true. A
// nil input yields nil.
func groupEntityToService(g *dbent.Group) *service.Group {
	if g == nil {
		return nil
	}

	out := new(service.Group)
	out.ID = g.ID
	out.Name = g.Name
	out.Description = derefString(g.Description)
	out.Platform = g.Platform
	out.RateMultiplier = g.RateMultiplier
	out.IsExclusive = g.IsExclusive
	out.Status = g.Status
	out.Hydrated = true
	out.SubscriptionType = g.SubscriptionType
	out.DailyLimitUSD = g.DailyLimitUsd
	out.WeeklyLimitUSD = g.WeeklyLimitUsd
	out.MonthlyLimitUSD = g.MonthlyLimitUsd
	out.ImagePrice1K = g.ImagePrice1k
	out.ImagePrice2K = g.ImagePrice2k
	out.ImagePrice4K = g.ImagePrice4k
	out.DefaultValidityDays = g.DefaultValidityDays
	out.ClaudeCodeOnly = g.ClaudeCodeOnly
	out.FallbackGroupID = g.FallbackGroupID
	out.CreatedAt = g.CreatedAt
	out.UpdatedAt = g.UpdatedAt
	return out
}
// derefString returns the string s points to, or "" when s is nil.
func derefString(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}