perf(认证): 引入 API Key 认证缓存与轻量删除查询

增加 L1/L2 缓存、负缓存与单飞回源
使用 key+owner 轻量查询替代全量加载并清理旧接口
补充缓存失效与余额更新测试,修复随机抖动 lint

测试: make test
This commit is contained in:
yangjianbo
2026-01-10 22:23:51 +08:00
parent e79dbad602
commit 9d0a4f3d68
22 changed files with 1360 additions and 99 deletions

View File

@@ -57,13 +57,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
@@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()

View File

@@ -44,11 +44,13 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgraph-io/ristretto v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect

View File

@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=

View File

@@ -49,6 +49,7 @@ type Config struct {
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
@@ -361,6 +362,16 @@ type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
type APIKeyAuthCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
Singleflight bool `mapstructure:"singleflight"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -655,6 +666,14 @@ func setDefaults() {
// Timezone (default to Asia/Shanghai for Chinese users)
viper.SetDefault("timezone", "Asia/Shanghai")
// API Key auth cache
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
@@ -13,6 +14,7 @@ import (
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct {
rdb *redis.Client
}
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}

View File

@@ -6,7 +6,9 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
return apiKeyEntityToService(m), nil
}
// GetOwnerID 根据 API Key ID 获取其所有者(用户)ID。
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者用户ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 使用 Select() 只查询必要字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
// - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID).
Select(apikey.FieldKey, apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return 0, err
return "", 0, err
}
return m.UserID, nil
return m.Key, m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil

View File

@@ -389,7 +389,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
@@ -565,6 +565,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
return nil
}
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
@@ -737,12 +749,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return &clone, nil
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
return key.Key, key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -754,6 +766,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return &clone, nil
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
@@ -868,6 +884,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}

View File

@@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
@@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
}
return f.getByKey(ctx, key)
}
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return f.GetByKey(ctx, key)
}
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {

View File

@@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -252,6 +252,7 @@ type adminServiceImpl struct {
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewAdminService creates a new AdminService
@@ -264,6 +265,7 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -274,6 +276,7 @@ func NewAdminService(
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
if input.Email != "" {
user.Email = input.Email
@@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
@@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
return err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
}
return nil
}
@@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
balanceDiff := user.Balance - oldBalance
if s.authCacheInvalidator != nil && balanceDiff != 0 {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
go func() {
@@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
@@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
var groupKeys []string
if s.authCacheInvalidator != nil {
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
if err == nil {
groupKeys = keys
}
}
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
@@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
}()
}
if s.authCacheInvalidator != nil {
for _, key := range groupKeys {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
}
}
return nil
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type balanceUserRepoStub struct {
*userRepoStub
updateErr error
updated []*User
}
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
if s.updateErr != nil {
return s.updateErr
}
if user == nil {
return nil
}
clone := *user
s.updated = append(s.updated, &clone)
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
type balanceRedeemRepoStub struct {
*redeemRepoStub
created []*RedeemCode
}
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
clone := *code
s.created = append(s.created, &clone)
return nil
}
type authCacheInvalidatorStub struct {
userIDs []int64
groupIDs []int64
keys []string
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
s.keys = append(s.keys, key)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
s.userIDs = append(s.userIDs, userID)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
s.groupIDs = append(s.groupIDs, groupID)
}
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
require.NoError(t, err)
require.Equal(t, []int64{7}, invalidator.userIDs)
require.Len(t, redeemRepo.created, 1)
}
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
require.NoError(t, err)
require.Empty(t, invalidator.userIDs)
require.Empty(t, redeemRepo.created)
}

View File

@@ -0,0 +1,46 @@
package service
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
User APIKeyAuthUserSnapshot `json:"user"`
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
}
// APIKeyAuthUserSnapshot 用户快照
type APIKeyAuthUserSnapshot struct {
ID int64 `json:"id"`
Status string `json:"status"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
}
// APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
type APIKeyAuthCacheEntry struct {
NotFound bool `json:"not_found"`
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}

View File

@@ -0,0 +1,269 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
l2TTL time.Duration
negativeTTL time.Duration
jitterPercent int
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
}
auth := cfg.APIKeyAuth
return apiKeyAuthCacheConfig{
l1Size: auth.L1Size,
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
jitterPercent: auth.JitterPercent,
singleflight: auth.Singleflight,
}
}
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
return c.l1Size > 0 && c.l1TTL > 0
}
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
return c.l2TTL > 0
}
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
}
if c.jitterPercent <= 0 {
return ttl
}
percent := c.jitterPercent
if percent > 100 {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
if !s.authCfg.l1Enabled() {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(s.authCfg.l1Size) * 10,
MaxCost: int64(s.authCfg.l1Size),
BufferItems: 64,
})
if err != nil {
return
}
s.authCacheL1 = cache
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
}
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
if s.authCacheL1 != nil {
if val, ok := s.authCacheL1.Get(cacheKey); ok {
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
return entry, true
}
}
}
if s.cache == nil || !s.authCfg.l2Enabled() {
return nil, false
}
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
if err != nil {
return nil, false
}
s.setAuthCacheL1(cacheKey, entry)
return entry, true
}
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
if s.authCacheL1 == nil || entry == nil {
return
}
ttl := s.authCfg.l1TTL
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
ttl = s.authCfg.negativeTTL
}
ttl = s.authCfg.jitterTTL(ttl)
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
if entry == nil {
return
}
s.setAuthCacheL1(cacheKey, entry)
if s.cache == nil || !s.authCfg.l2Enabled() {
return
}
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
if s.authCacheL1 != nil {
s.authCacheL1.Del(cacheKey)
}
if s.cache == nil {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
if errors.Is(err, ErrAPIKeyNotFound) {
entry := &APIKeyAuthCacheEntry{NotFound: true}
if s.authCfg.negativeEnabled() {
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
}
return entry, nil
}
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
return entry, nil
}
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
if entry == nil {
return nil, false, nil
}
if entry.NotFound {
return nil, true, ErrAPIKeyNotFound
}
if entry.Snapshot == nil {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
},
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
}
}
return snapshot
}
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
if snapshot == nil {
return nil
}
apiKey := &APIKey{
ID: snapshot.APIKeyID,
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
},
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
}
}
return apiKey
}

View File

@@ -0,0 +1,48 @@
package service
import "context"
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
if key == "" {
return
}
cacheKey := s.authCacheKey(key)
s.deleteAuthCache(ctx, cacheKey)
}
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
if userID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
if groupID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
if len(keys) == 0 {
return
}
for _, key := range keys {
if key == "" {
continue
}
s.deleteAuthCache(ctx, s.authCacheKey(key))
}
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var (
@@ -31,9 +33,11 @@ const (
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID用于删除等轻量场景
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
}
// APIKeyCache defines cache operations for API key service
@@ -55,6 +61,17 @@ type APIKeyCache interface {
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type APIKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}
// CreateAPIKeyRequest 创建API Key请求
@@ -83,6 +100,9 @@ type APIKeyService struct {
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -94,7 +114,7 @@ func NewAPIKeyService(
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -102,6 +122,8 @@ func NewAPIKeyService(
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
}
// GenerateKey 生成随机API Key
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("create api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
cacheKey := s.authCacheKey(key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
return apiKey, nil
}
}
if s.authCfg.singleflight {
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
return s.loadAuthCacheEntry(ctx, key, cacheKey)
})
if err != nil {
return nil, err
}
entry, _ := value.(*APIKeyAuthCacheEntry)
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
} else {
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
if err != nil {
return nil, err
}
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
return apiKey, nil
}
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
// 清除Redis缓存使用 userID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
}
s.InvalidateAuthCacheByKey(ctx, key)
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)

View File

@@ -0,0 +1,417 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
type authRepoStub struct {
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
if s.getByKeyForAuth == nil {
panic("unexpected GetByKeyForAuth call")
}
return s.getByKeyForAuth(ctx, key)
}
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
if s.listKeysByUserID == nil {
panic("unexpected ListKeysByUserID call")
}
return s.listKeysByUserID(ctx, userID)
}
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
if s.listKeysByGroupID == nil {
panic("unexpected ListKeysByGroupID call")
}
return s.listKeysByGroupID(ctx, groupID)
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string
deleteAuthKeys []string
}
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
if s.getAuthCache == nil {
return nil, redis.Nil
}
return s.getAuthCache(ctx, key)
}
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
s.setAuthKeys = append(s.setAuthKeys, key)
return nil
}
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "g",
Platform: PlatformAnthropic,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
},
},
}
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return cacheEntry, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k1")
require.NoError(t, err)
require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return &APIKey{
ID: 5,
UserID: 7,
Status: StatusActive,
User: &User{
ID: 7,
Status: StatusActive,
Role: RoleUser,
Balance: 12,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
apiKey, err := svc.GetByKey(context.Background(), "k2")
require.NoError(t, err)
require.Equal(t, int64(5), apiKey.ID)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
return &APIKey{
ID: 21,
UserID: 3,
Status: StatusActive,
User: &User{
ID: 3,
Status: StatusActive,
Role: RoleUser,
Balance: 5,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L1Size: 1000,
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
svc.authCacheL1.Wait()
cacheKey := svc.authCacheKey("k-l1")
_, ok := svc.authCacheL1.Get(cacheKey)
require.True(t, ok)
_, err = svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return nil, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, ErrAPIKeyNotFound
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
time.Sleep(50 * time.Millisecond)
return &APIKey{
ID: 11,
UserID: 2,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 1,
Concurrency: 1,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}
errs := make([]error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
<-start
_, err := svc.GetByKey(context.Background(), "k1")
errs[idx] = err
}(i)
}
close(start)
wg.Wait()
for _, err := range errs {
require.NoError(t, err)
}
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

View File

@@ -20,13 +20,12 @@ import (
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
ownerID int64 // GetOwnerID 的返回值
ownerErr error // GetOwnerID 的错误返回值
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
}
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
if s.getByIDErr != nil {
return nil, s.getByIDErr
}
if s.apiKey != nil {
clone := *s.apiKey
return &clone, nil
}
panic("unexpected GetByID call")
}
// GetOwnerID 返回预设的所有者 ID 或错误。
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return s.ownerID, s.ownerErr
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
if s.getByIDErr != nil {
return "", 0, s.getByIDErr
}
if s.apiKey != nil {
return s.apiKey.Key, s.apiKey.UserID, nil
}
return "", 0, ErrAPIKeyNotFound
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -96,6 +110,14 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
@@ -103,6 +125,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
// - invalidated: 记录被清除缓存的用户 ID 列表
type apiKeyCacheStub struct {
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0表示用户未超过创建次数限制
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.ErrorIs(t, err, ErrInsufficientPerms)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
require.Empty(t, cache.invalidated) // 验证缓存未被清除
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - GetKeyAndOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - GetKeyAndOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
require.ErrorContains(t, err, "delete api key")
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}

View File

@@ -51,12 +51,14 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
return &GroupService{
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
@@ -170,6 +175,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return nil
}

View File

@@ -56,12 +56,14 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
return &UserService{
userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
// 更新字段
if req.Email != nil {
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return user, nil
}
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
return fmt.Errorf("update concurrency: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -192,6 +204,9 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -201,5 +216,8 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}

View File

@@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
return svc
}
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
return apiKeyService
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
NewProxyService,

View File

@@ -170,6 +170,30 @@ gateway:
# 允许在特定 400 错误时进行故障转移(默认:关闭)
failover_on_400: false
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
# =============================================================================
api_key_auth_cache:
# L1 cache size (entries), in-process LRU/TTL cache
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
l1_size: 65535
# L1 cache TTL (seconds)
# L1 缓存 TTL
l1_ttl_seconds: 15
# L2 cache TTL (seconds), stored in Redis
# L2 缓存 TTLRedis 中存储
l2_ttl_seconds: 300
# Negative cache TTL (seconds)
# 负缓存 TTL
negative_ttl_seconds: 30
# TTL jitter percent (0-100)
# TTL 抖动百分比0-100
jitter_percent: 10
# Enable singleflight for cache misses
# 缓存未命中时启用 singleflight 合并回源
singleflight: true
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置

View File

@@ -170,6 +170,30 @@ gateway:
# 允许在特定 400 错误时进行故障转移(默认:关闭)
failover_on_400: false
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
# =============================================================================
api_key_auth_cache:
# L1 cache size (entries), in-process LRU/TTL cache
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
l1_size: 65535
# L1 cache TTL (seconds)
# L1 缓存 TTL
l1_ttl_seconds: 15
# L2 cache TTL (seconds), stored in Redis
# L2 缓存 TTLRedis 中存储
l2_ttl_seconds: 300
# Negative cache TTL (seconds)
# 负缓存 TTL
negative_ttl_seconds: 30
# TTL jitter percent (0-100)
# TTL 抖动百分比0-100
jitter_percent: 10
# Enable singleflight for cache misses
# 缓存未命中时启用 singleflight 合并回源
singleflight: true
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置