feat(api-key): 增加 API Key 上次使用时间并补齐测试
This commit is contained in:
@@ -19,6 +19,7 @@ type APIKey struct {
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -32,6 +34,7 @@ var (
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyLastUsedMinTouch = 30 * time.Second
|
||||
)
|
||||
|
||||
type APIKeyRepository interface {
|
||||
@@ -58,6 +61,7 @@ type APIKeyRepository interface {
|
||||
|
||||
// Quota methods
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
@@ -125,6 +129,8 @@ type APIKeyService struct {
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> time.Time
|
||||
lastUsedTouchSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -527,6 +533,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
}
|
||||
s.lastUsedTouchL1.Delete(id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -558,6 +565,37 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
|
||||
return apiKey, user, nil
|
||||
}
|
||||
|
||||
// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。
|
||||
// 该操作为尽力而为,不应阻塞主请求链路。
|
||||
func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
|
||||
if keyID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if last, ok := v.(time.Time); ok && now.Sub(last) < apiKeyLastUsedMinTouch {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
|
||||
latest := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if last, ok := v.(time.Time); ok && latest.Sub(last) < apiKeyLastUsedMinTouch {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
|
||||
return nil, fmt.Errorf("touch api key last used: %w", err)
|
||||
}
|
||||
s.lastUsedTouchL1.Store(keyID, latest)
|
||||
return nil, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
|
||||
@@ -103,6 +103,10 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
setAuthKeys []string
|
||||
|
||||
@@ -24,10 +24,13 @@ import (
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
touchedIDs []int64
|
||||
touchedUsedAts []time.Time
|
||||
}
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
@@ -122,6 +125,15 @@ func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amoun
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
s.touchedIDs = append(s.touchedIDs, id)
|
||||
s.touchedUsedAts = append(s.touchedUsedAts, usedAt)
|
||||
if s.updateLastUsed != nil {
|
||||
return s.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
@@ -214,12 +226,15 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc.lastUsedTouchL1.Store(int64(42), time.Now())
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
_, exists := svc.lastUsedTouchL1.Load(int64(42))
|
||||
require.False(t, exists, "delete should clear touch debounce cache")
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
|
||||
141
backend/internal/service/api_key_service_touch_last_used_test.go
Normal file
141
backend/internal/service/api_key_service_touch_last_used_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("should not be called")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 0))
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), -1))
|
||||
require.Empty(t, repo.touchedIDs)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
err := svc.TouchLastUsed(context.Background(), 123)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs)
|
||||
require.Len(t, repo.touchedUsedAts, 1)
|
||||
require.False(t, repo.touchedUsedAts[0].IsZero())
|
||||
|
||||
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.True(t, ok, "successful touch should update debounce cache")
|
||||
_, isTime := cached.(time.Time)
|
||||
require.True(t, isTime)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository")
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
|
||||
// 强制将 debounce 时间回拨到窗口之外,触发第二次写库。
|
||||
svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second))
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
require.Len(t, repo.touchedIDs, 2)
|
||||
require.Equal(t, int64(123), repo.touchedIDs[0])
|
||||
require.Equal(t, int64(123), repo.touchedIDs[1])
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("db write failed")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
err := svc.TouchLastUsed(context.Background(), 123)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "touch api key last used")
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs)
|
||||
|
||||
_, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.False(t, ok, "failed touch should not update debounce cache")
|
||||
}
|
||||
|
||||
type touchSingleflightRepo struct {
|
||||
*apiKeyRepoStub
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
blockCh chan struct{}
|
||||
}
|
||||
|
||||
func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
r.calls++
|
||||
r.mu.Unlock()
|
||||
<-r.blockCh
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) {
|
||||
repo := &touchSingleflightRepo{
|
||||
apiKeyRepoStub: &apiKeyRepoStub{},
|
||||
blockCh: make(chan struct{}),
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
const workers = 20
|
||||
startCh := make(chan struct{})
|
||||
errCh := make(chan error, workers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startCh
|
||||
errCh <- svc.TouchLastUsed(context.Background(), 321)
|
||||
}()
|
||||
}
|
||||
|
||||
close(startCh)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
return repo.calls >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
close(repo.blockCh)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次")
|
||||
}
|
||||
Reference in New Issue
Block a user