From 9d0a4f3d68ea9aea079cf2292e968550561dd3a9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:23:51 +0800 Subject: [PATCH] =?UTF-8?q?perf(=E8=AE=A4=E8=AF=81):=20=E5=BC=95=E5=85=A5?= =?UTF-8?q?=20API=20Key=20=E8=AE=A4=E8=AF=81=E7=BC=93=E5=AD=98=E4=B8=8E?= =?UTF-8?q?=E8=BD=BB=E9=87=8F=E5=88=A0=E9=99=A4=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加 L1/L2 缓存、负缓存与单飞回源 使用 key+owner 轻量查询替代全量加载并清理旧接口 补充缓存失效与余额更新测试,修复随机抖动 lint 测试: make test --- backend/cmd/server/wire_gen.go | 9 +- backend/go.mod | 2 + backend/go.sum | 4 + backend/internal/config/config.go | 57 ++- backend/internal/repository/api_key_cache.go | 33 ++ backend/internal/repository/api_key_repo.go | 88 +++- backend/internal/server/api_contract_test.go | 32 +- .../middleware/api_key_auth_google_test.go | 13 +- .../server/middleware/api_key_auth_test.go | 16 +- backend/internal/service/admin_service.go | 66 ++- .../admin_service_update_balance_test.go | 97 ++++ .../internal/service/api_key_auth_cache.go | 46 ++ .../service/api_key_auth_cache_impl.go | 269 +++++++++++ .../service/api_key_auth_cache_invalidate.go | 48 ++ backend/internal/service/api_key_service.go | 92 +++- .../service/api_key_service_cache_test.go | 417 ++++++++++++++++++ .../service/api_key_service_delete_test.go | 78 +++- backend/internal/service/group_service.go | 14 +- backend/internal/service/user_service.go | 24 +- backend/internal/service/wire.go | 6 + config.yaml | 24 + deploy/config.example.yaml | 24 + 22 files changed, 1360 insertions(+), 99 deletions(-) create mode 100644 backend/internal/service/admin_service_update_balance_test.go create mode 100644 backend/internal/service/api_key_auth_cache.go create mode 100644 backend/internal/service/api_key_auth_cache_impl.go create mode 100644 backend/internal/service/api_key_auth_cache_invalidate.go create mode 100644 backend/internal/service/api_key_service_cache_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 30ea0fdb..a372f673 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -57,13 +57,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) - userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) + userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client) @@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() diff --git a/backend/go.mod b/backend/go.mod index 9ac48305..82a8e88e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -44,11 +44,13 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 38e2b53e..0fd47498 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= +github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 2cc11967..29eaa42e 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -36,25 +36,26 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -361,6 +362,16 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +// APIKeyAuthCacheConfig API Key 认证缓存配置 +type APIKeyAuthCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` + NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` + Singleflight bool `mapstructure:"singleflight"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -655,6 +666,14 @@ func setDefaults() { // Timezone (default to Asia/Shanghai for Chinese users) viper.SetDefault("timezone", "Asia/Shanghai") + // API Key auth cache + viper.SetDefault("api_key_auth_cache.l1_size", 65535) + viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) + viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) + viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30) + viper.SetDefault("api_key_auth_cache.jitter_percent", 10) + viper.SetDefault("api_key_auth_cache.singleflight", true) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 73a929c5..6d834b40 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -13,6 +14,7 @@ import ( const ( apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" ) // apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. @@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string { return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) } +func apiKeyAuthCacheKey(key string) string { + return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key) +} + type apiKeyCache struct { rdb *redis.Client } @@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { return c.rdb.Expire(ctx, apiKey, ttl).Err() } + +func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes() + if err != nil { + return nil, err + } + var entry service.APIKeyAuthCacheEntry + if err := json.Unmarshal(val, &entry); err != nil { + return nil, err + } + return &entry, nil +} + +func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + payload, err := json.Marshal(entry) + if err != nil { + return err + } + return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err() +} + +func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 6b8cd40d..77a3f233 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -6,7 +6,9 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK return apiKeyEntityToService(m), nil } -// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。 +// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。 // 相比 GetByID,此方法性能更优,因为: -// - 使用 Select() 只查询 user_id 字段,减少数据传输量 +// - 使用 Select() 只查询必要字段,减少数据传输量 // - 不加载完整的 API Key 实体及其关联数据(User、Group 等) -// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) -func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { +// - 适用于删除等只需 key 与用户 ID 的场景 +func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { m, err := r.activeQuery(). Where(apikey.IDEQ(id)). - Select(apikey.FieldUserID). + Select(apikey.FieldKey, apikey.FieldUserID). Only(ctx) if err != nil { if dbent.IsNotFound(err) { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return 0, err + return "", 0, err } - return m.UserID, nil + return m.Key, m.UserID, nil } func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A return apiKeyEntityToService(m), nil } +func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + m, err := r.activeQuery(). + Where(apikey.KeyEQ(key)). + Select( + apikey.FieldID, + apikey.FieldUserID, + apikey.FieldGroupID, + apikey.FieldStatus, + apikey.FieldIPWhitelist, + apikey.FieldIPBlacklist, + ). + WithUser(func(q *dbent.UserQuery) { + q.Select( + user.FieldID, + user.FieldStatus, + user.FieldRole, + user.FieldBalance, + user.FieldConcurrency, + ) + }). + WithGroup(func(q *dbent.GroupQuery) { + q.Select( + group.FieldID, + group.FieldName, + group.FieldPlatform, + group.FieldStatus, + group.FieldSubscriptionType, + group.FieldRateMultiplier, + group.FieldDailyLimitUsd, + group.FieldWeeklyLimitUsd, + group.FieldMonthlyLimitUsd, + group.FieldImagePrice1k, + group.FieldImagePrice2k, + group.FieldImagePrice4k, + group.FieldClaudeCodeOnly, + group.FieldFallbackGroupID, + ) + }). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return apiKeyEntityToService(m), nil +} + func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error { // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, @@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i return int64(count), err } +func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.UserIDEQ(userID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + +func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.GroupIDEQ(groupID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 41d8bfdb..bd02f47d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -389,7 +389,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo) + userService := service.NewUserService(userRepo, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() @@ -565,6 +565,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t return nil } +func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return nil +} + type stubGroupRepo struct{} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { @@ -737,12 +749,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return &clone, nil } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { key, ok := r.byID[id] if !ok { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return key.UserID, nil + return key.Key, key.UserID, nil } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -754,6 +766,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return &clone, nil } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { if key == nil { return errors.New("nil key") @@ -868,6 +884,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 07b8e370..6f09469b 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } -func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if f.getByKey == nil { @@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK } return f.getByKey(ctx, key) } +func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return f.GetByKey(ctx, key) +} func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64 func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} +func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 182ea5f8..84398093 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return nil, errors.New("not implemented") } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 14bb6daf..75b57852 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -244,14 +244,15 @@ type ProxyExitInfoProber interface { // adminServiceImpl implements AdminService type adminServiceImpl struct { - userRepo UserRepository - groupRepo GroupRepository - accountRepo AccountRepository - proxyRepo ProxyRepository - apiKeyRepo APIKeyRepository - redeemCodeRepo RedeemCodeRepository - billingCacheService *BillingCacheService - proxyProber ProxyExitInfoProber + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + proxyRepo ProxyRepository + apiKeyRepo APIKeyRepository + redeemCodeRepo RedeemCodeRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewAdminService creates a new AdminService @@ -264,16 +265,18 @@ func NewAdminService( redeemCodeRepo RedeemCodeRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) AdminService { return &adminServiceImpl{ - userRepo: userRepo, - groupRepo: groupRepo, - accountRepo: accountRepo, - proxyRepo: proxyRepo, - apiKeyRepo: apiKeyRepo, - redeemCodeRepo: redeemCodeRepo, - billingCacheService: billingCacheService, - proxyProber: proxyProber, + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + authCacheInvalidator: authCacheInvalidator, } } @@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role if input.Email != "" { user.Email = input.Email @@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { @@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { log.Printf("delete user failed: user_id=%d err=%v", id, err) return err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } return nil } @@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + balanceDiff := user.Balance - oldBalance + if s.authCacheInvalidator != nil && balanceDiff != 0 { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService != nil { go func() { @@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, }() } - balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { @@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } + } + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err @@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { } }() } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } return nil } diff --git a/backend/internal/service/admin_service_update_balance_test.go b/backend/internal/service/admin_service_update_balance_test.go new file mode 100644 index 00000000..d3b3c700 --- /dev/null +++ b/backend/internal/service/admin_service_update_balance_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type balanceUserRepoStub struct { + *userRepoStub + updateErr error + updated []*User +} + +func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error { + if s.updateErr != nil { + return s.updateErr + } + if user == nil { + return nil + } + clone := *user + s.updated = append(s.updated, &clone) + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +type balanceRedeemRepoStub struct { + *redeemRepoStub + created []*RedeemCode +} + +func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + clone := *code + s.created = append(s.created, &clone) + return nil +} + +type authCacheInvalidatorStub struct { + userIDs []int64 + groupIDs []int64 + keys []string +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) { + s.keys = append(s.keys, key) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + s.userIDs = append(s.userIDs, userID) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + s.groupIDs = append(s.groupIDs, groupID) +} + +func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "") + require.NoError(t, err) + require.Equal(t, []int64{7}, invalidator.userIDs) + require.Len(t, redeemRepo.created, 1) +} + +func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "") + require.NoError(t, err) + require.Empty(t, invalidator.userIDs) + require.Empty(t, redeemRepo.created) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go new file mode 100644 index 00000000..7ce9a8a2 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache.go @@ -0,0 +1,46 @@ +package service + +// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) +type APIKeyAuthSnapshot struct { + APIKeyID int64 `json:"api_key_id"` + UserID int64 `json:"user_id"` + GroupID *int64 `json:"group_id,omitempty"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist,omitempty"` + IPBlacklist []string `json:"ip_blacklist,omitempty"` + User APIKeyAuthUserSnapshot `json:"user"` + Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` +} + +// APIKeyAuthUserSnapshot 用户快照 +type APIKeyAuthUserSnapshot struct { + ID int64 `json:"id"` + Status string `json:"status"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` +} + +// APIKeyAuthGroupSnapshot 分组快照 +type APIKeyAuthGroupSnapshot struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` +} + +// APIKeyAuthCacheEntry 缓存条目,支持负缓存 +type APIKeyAuthCacheEntry struct { + NotFound bool `json:"not_found"` + Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"` +} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go new file mode 100644 index 00000000..dfc55eeb --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -0,0 +1,269 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/dgraph-io/ristretto" +) + +type apiKeyAuthCacheConfig struct { + l1Size int + l1TTL time.Duration + l2TTL time.Duration + negativeTTL time.Duration + jitterPercent int + singleflight bool +} + +var ( + jitterRandMu sync.Mutex + // 认证缓存抖动使用独立随机源,避免全局 Seed + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { + if cfg == nil { + return apiKeyAuthCacheConfig{} + } + auth := cfg.APIKeyAuth + return apiKeyAuthCacheConfig{ + l1Size: auth.L1Size, + l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second, + l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second, + negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second, + jitterPercent: auth.JitterPercent, + singleflight: auth.Singleflight, + } +} + +func (c apiKeyAuthCacheConfig) l1Enabled() bool { + return c.l1Size > 0 && c.l1TTL > 0 +} + +func (c apiKeyAuthCacheConfig) l2Enabled() bool { + return c.l2TTL > 0 +} + +func (c apiKeyAuthCacheConfig) negativeEnabled() bool { + return c.negativeTTL > 0 +} + +func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return ttl + } + if c.jitterPercent <= 0 { + return ttl + } + percent := c.jitterPercent + if percent > 100 { + percent = 100 + } + delta := float64(percent) / 100 + jitterRandMu.Lock() + randVal := jitterRand.Float64() + jitterRandMu.Unlock() + factor := 1 - delta + randVal*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +func (s *APIKeyService) initAuthCache(cfg *config.Config) { + s.authCfg = newAPIKeyAuthCacheConfig(cfg) + if !s.authCfg.l1Enabled() { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(s.authCfg.l1Size) * 10, + MaxCost: int64(s.authCfg.l1Size), + BufferItems: 64, + }) + if err != nil { + return + } + s.authCacheL1 = cache +} + +func (s *APIKeyService) authCacheKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) { + if s.authCacheL1 != nil { + if val, ok := s.authCacheL1.Get(cacheKey); ok { + if entry, ok := val.(*APIKeyAuthCacheEntry); ok { + return entry, true + } + } + } + if s.cache == nil || !s.authCfg.l2Enabled() { + return nil, false + } + entry, err := s.cache.GetAuthCache(ctx, cacheKey) + if err != nil { + return nil, false + } + s.setAuthCacheL1(cacheKey, entry) + return entry, true +} + +func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) { + if s.authCacheL1 == nil || entry == nil { + return + } + ttl := s.authCfg.l1TTL + if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl { + ttl = s.authCfg.negativeTTL + } + ttl = s.authCfg.jitterTTL(ttl) + _ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl) +} + +func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) { + if entry == nil { + return + } + s.setAuthCacheL1(cacheKey, entry) + if s.cache == nil || !s.authCfg.l2Enabled() { + return + } + _ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl)) +} + +func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { + if s.authCacheL1 != nil { + s.authCacheL1.Del(cacheKey) + } + if s.cache == nil { + return + } + _ = s.cache.DeleteAuthCache(ctx, cacheKey) +} + +func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) + if err != nil { + if errors.Is(err, ErrAPIKeyNotFound) { + entry := &APIKeyAuthCacheEntry{NotFound: true} + if s.authCfg.negativeEnabled() { + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL) + } + return entry, nil + } + return nil, fmt.Errorf("get api key: %w", err) + } + apiKey.Key = key + snapshot := s.snapshotFromAPIKey(apiKey) + if snapshot == nil { + return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) + } + entry := &APIKeyAuthCacheEntry{Snapshot: snapshot} + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL) + return entry, nil +} + +func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) { + if entry == nil { + return nil, false, nil + } + if entry.NotFound { + return nil, true, ErrAPIKeyNotFound + } + if entry.Snapshot == nil { + return nil, false, nil + } + return s.snapshotToAPIKey(key, entry.Snapshot), true, nil +} + +func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { + if apiKey == nil || apiKey.User == nil { + return nil + } + snapshot := &APIKeyAuthSnapshot{ + APIKeyID: apiKey.ID, + UserID: apiKey.UserID, + GroupID: apiKey.GroupID, + Status: apiKey.Status, + IPWhitelist: apiKey.IPWhitelist, + IPBlacklist: apiKey.IPBlacklist, + User: APIKeyAuthUserSnapshot{ + ID: apiKey.User.ID, + Status: apiKey.User.Status, + Role: apiKey.User.Role, + Balance: apiKey.User.Balance, + Concurrency: apiKey.User.Concurrency, + }, + } + if apiKey.Group != nil { + snapshot.Group = &APIKeyAuthGroupSnapshot{ + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + } + } + return snapshot +} + +func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey { + if snapshot == nil { + return nil + } + apiKey := &APIKey{ + ID: snapshot.APIKeyID, + UserID: snapshot.UserID, + GroupID: snapshot.GroupID, + Key: key, + Status: snapshot.Status, + IPWhitelist: snapshot.IPWhitelist, + IPBlacklist: snapshot.IPBlacklist, + User: &User{ + ID: snapshot.User.ID, + Status: snapshot.User.Status, + Role: snapshot.User.Role, + Balance: snapshot.User.Balance, + Concurrency: snapshot.User.Concurrency, + }, + } + if snapshot.Group != nil { + apiKey.Group = &Group{ + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + } + } + return apiKey +} diff --git a/backend/internal/service/api_key_auth_cache_invalidate.go b/backend/internal/service/api_key_auth_cache_invalidate.go new file mode 100644 index 00000000..aeb58bcc --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_invalidate.go @@ -0,0 +1,48 @@ +package service + +import "context" + +// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) { + if key == "" { + return + } + cacheKey := s.authCacheKey(key) + s.deleteAuthCache(ctx, cacheKey) +} + +// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + if userID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + if groupID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) { + if len(keys) == 0 { + return + } + for _, key := range keys { + if key == "" { + continue + } + s.deleteAuthCache(ctx, s.authCacheKey(key)) + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 578afc1a..ecc570c7 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -12,6 +12,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) var ( @@ -31,9 +33,11 @@ const ( type APIKeyRepository interface { Create(ctx context.Context, key *APIKey) error GetByID(ctx context.Context, id int64) (*APIKey, error) - // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 - GetOwnerID(ctx context.Context, id int64) (int64, error) + // GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景 + GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) GetByKey(ctx context.Context, key string) (*APIKey, error) + // GetByKeyForAuth 认证专用查询,返回最小字段集 + GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error @@ -45,6 +49,8 @@ type APIKeyRepository interface { SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) + ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) + ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) } // APIKeyCache defines cache operations for API key service @@ -55,6 +61,17 @@ type APIKeyCache interface { IncrementDailyUsage(ctx context.Context, apiKey string) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error + + GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error + DeleteAuthCache(ctx context.Context, key string) error +} + +// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 +type APIKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) + InvalidateAuthCacheByUserID(ctx context.Context, userID int64) + InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) } // CreateAPIKeyRequest 创建API Key请求 @@ -83,6 +100,9 @@ type APIKeyService struct { userSubRepo UserSubscriptionRepository cache APIKeyCache cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -94,7 +114,7 @@ func NewAPIKeyService( cache APIKeyCache, cfg *config.Config, ) *APIKeyService { - return &APIKeyService{ + svc := &APIKeyService{ apiKeyRepo: apiKeyRepo, userRepo: userRepo, groupRepo: groupRepo, @@ -102,6 +122,8 @@ func NewAPIKeyService( cache: cache, cfg: cfg, } + svc.initAuthCache(cfg) + return svc } // GenerateKey 生成随机API Key @@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("create api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } @@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) // GetByKey 根据Key字符串获取API Key(用于认证) func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { - // 尝试从Redis缓存获取 - cacheKey := fmt.Sprintf("apikey:%s", key) + cacheKey := s.authCacheKey(key) - // 这里可以添加Redis缓存逻辑,暂时直接查询数据库 - apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) + if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok { + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + if s.authCfg.singleflight { + value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) { + return s.loadAuthCacheEntry(ctx, key, cacheKey) + }) + if err != nil { + return nil, err + } + entry, _ := value.(*APIKeyAuthCacheEntry) + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } else { + entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey) + if err != nil { + return nil, err + } + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } - - // 缓存到Redis(可选,TTL设置为5分钟) - if s.cache != nil { - // 这里可以序列化并缓存API Key - _ = cacheKey // 使用变量避免未使用错误 - } - + apiKey.Key = key return apiKey, nil } @@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, fmt.Errorf("update api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } // Delete 删除API Key -// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, -// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能 func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { - // 仅获取所有者 ID 用于权限验证,而非加载完整对象 - ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) + key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id) if err != nil { return fmt.Errorf("get api key: %w", err) } @@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro return ErrInsufficientPerms } - // 清除Redis缓存(使用 ownerID 而非 apiKey.UserID) + // 清除Redis缓存(使用 userID 而非 apiKey.UserID) if s.cache != nil { - _ = s.cache.DeleteCreateAttemptCount(ctx, ownerID) + _ = s.cache.DeleteCreateAttemptCount(ctx, userID) } + s.InvalidateAuthCacheByKey(ctx, key) if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go new file mode 100644 index 00000000..3314ca8d --- /dev/null +++ b/backend/internal/service/api_key_service_cache_test.go @@ -0,0 +1,417 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +type authRepoStub struct { + getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error) + listKeysByUserID func(ctx context.Context, userID int64) ([]string, error) + listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error) +} + +func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error { + panic("unexpected Create call") +} + +func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + panic("unexpected GetByID call") +} + +func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} + +func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKey call") +} + +func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + if s.getByKeyForAuth == nil { + panic("unexpected GetByKeyForAuth call") + } + return s.getByKeyForAuth(ctx, key) +} + +func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error { + panic("unexpected Update call") +} + +func (s *authRepoStub) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} + +func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} + +func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) { + panic("unexpected CountByUserID call") +} + +func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) { + panic("unexpected ExistsByKey call") +} + +func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} + +func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} + +func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} + +func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected CountByGroupID call") +} + +func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + if s.listKeysByUserID == nil { + panic("unexpected ListKeysByUserID call") + } + return s.listKeysByUserID(ctx, userID) +} + +func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + if s.listKeysByGroupID == nil { + panic("unexpected ListKeysByGroupID call") + } + return s.listKeysByGroupID(ctx, groupID) +} + +type authCacheStub struct { + getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + setAuthKeys []string + deleteAuthKeys []string +} + +func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return nil +} + +func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return nil +} + +func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + if s.getAuthCache == nil { + return nil, redis.Nil + } + return s.getAuthCache(ctx, key) +} + +func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + s.setAuthKeys = append(s.setAuthKeys, key) + return nil +} + +func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + groupID := int64(9) + cacheEntry := &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "g", + Platform: PlatformAnthropic, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + } + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return cacheEntry, nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k1") + require.NoError(t, err) + require.Equal(t, int64(1), apiKey.ID) + require.Equal(t, int64(2), apiKey.User.ID) + require.Equal(t, groupID, apiKey.Group.ID) +} + +func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return &APIKeyAuthCacheEntry{NotFound: true}, nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return &APIKey{ + ID: 5, + UserID: 7, + Status: StatusActive, + User: &User{ + ID: 7, + Status: StatusActive, + Role: RoleUser, + Balance: 12, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k2") + require.NoError(t, err) + require.Equal(t, int64(5), apiKey.ID) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + return &APIKey{ + ID: 21, + UserID: 3, + Status: StatusActive, + User: &User{ + ID: 3, + Status: StatusActive, + Role: RoleUser, + Balance: 5, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L1Size: 1000, + L1TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + require.NotNil(t, svc.authCacheL1) + + _, err := svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + svc.authCacheL1.Wait() + cacheKey := svc.authCacheKey("k-l1") + _, ok := svc.authCacheL1.Get(cacheKey) + require.True(t, ok) + _, err = svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} + +func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByUserID(context.Background(), 7) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByGroupID(context.Background(), 9) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return nil, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByKey(context.Background(), "k1") + require.Len(t, cache.deleteAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, ErrAPIKeyNotFound + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + time.Sleep(50 * time.Millisecond) + return &APIKey{ + ID: 11, + UserID: 2, + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 1, + Concurrency: 1, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + Singleflight: true, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + start := make(chan struct{}) + wg := sync.WaitGroup{} + errs := make([]error, 5) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, err := svc.GetByKey(context.Background(), "k1") + errs[idx] = err + }(i) + } + close(start) + wg.Wait() + + for _, err := range errs { + require.NoError(t, err) + } + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 7d04c5ac..32ae884e 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -20,13 +20,12 @@ import ( // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。 // // 设计说明: -// - ownerID: 模拟 GetOwnerID 返回的所有者 ID -// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound) +// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误 // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - ownerID int64 // GetOwnerID 的返回值 - ownerErr error // GetOwnerID 的错误返回值 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 deleteErr error // Delete 的错误返回值 deletedIDs []int64 // 记录已删除的 API Key ID 列表 } @@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error { } func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + if s.getByIDErr != nil { + return nil, s.getByIDErr + } + if s.apiKey != nil { + clone := *s.apiKey + return &clone, nil + } panic("unexpected GetByID call") } -// GetOwnerID 返回预设的所有者 ID 或错误。 -// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。 -func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return s.ownerID, s.ownerErr +func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + if s.getByIDErr != nil { + return "", 0, s.getByIDErr + } + if s.apiKey != nil { + return s.apiKey.Key, s.apiKey.UserID, nil + } + return "", 0, ErrAPIKeyNotFound } func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { panic("unexpected GetByKey call") } +func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} + func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error { panic("unexpected Update call") } @@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int panic("unexpected CountByGroupID call") } +func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} + +func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // // 设计说明: // - invalidated: 记录被清除缓存的用户 ID 列表 type apiKeyCacheStub struct { - invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key } // GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制 @@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string return nil } +func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 1 +// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - 调用者 userID 为 2(不匹配) // - 返回 ErrInsufficientPerms 错误 // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 1} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { require.ErrorIs(t, err, ErrInsufficientPerms) require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用 require.Empty(t, cache.invalidated) // 验证缓存未被清除 + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 7 +// - GetKeyAndOwnerID 返回所有者 ID 为 7 // - 调用者 userID 为 7(匹配) // - Delete 成功执行 // - 缓存被正确清除(使用 ownerID) // - 返回 nil 错误 func TestApiKeyService_Delete_Success(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 7} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // 预期行为: -// - GetOwnerID 返回 ErrAPIKeyNotFound 错误 +// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误 // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装) // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_NotFound(t *testing.T) { - repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound} + repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound} cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) { require.ErrorIs(t, err, ErrAPIKeyNotFound) require.Empty(t, repo.deletedIDs) require.Empty(t, cache.invalidated) + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // 预期行为: -// - GetOwnerID 返回正确的所有者 ID +// - GetKeyAndOwnerID 返回正确的所有者 ID // - 所有权验证通过 // - 缓存被清除(在删除之前) // - Delete 被调用但返回错误 // - 返回包含 "delete api key" 的错误信息 func TestApiKeyService_Delete_DeleteFails(t *testing.T) { repo := &apiKeyRepoStub{ - ownerID: 3, + apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"}, deleteErr: errors.New("delete failed"), } cache := &apiKeyCacheStub{} @@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { require.ErrorContains(t, err, "delete api key") require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用 require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败) + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 2f0f4975..a9214c82 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -50,13 +50,15 @@ type UpdateGroupRequest struct { // GroupService 分组管理服务 type GroupService struct { - groupRepo GroupRepository + groupRepo GroupRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewGroupService 创建分组服务实例 -func NewGroupService(groupRepo GroupRepository) *GroupService { +func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService { return &GroupService{ - groupRepo: groupRepo, + groupRepo: groupRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } @@ -170,6 +175,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { if err := s.groupRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 08fa40b5..a7a36760 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -55,13 +55,15 @@ type ChangePasswordRequest struct { // UserService 用户服务 type UserService struct { - userRepo UserRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { return &UserService{ - userRepo: userRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err != nil { return nil, fmt.Errorf("get user: %w", err) } + oldConcurrency := user.Concurrency // 更新字段 if req.Email != nil { @@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return user, nil } @@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil { return fmt.Errorf("update balance: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil { return fmt.Errorf("update concurrency: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -192,6 +204,9 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str if err := s.userRepo.Update(ctx, user); err != nil { return fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -201,5 +216,8 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { if err := s.userRepo.Delete(ctx, userID); err != nil { return fmt.Errorf("delete user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 512d2550..54c37b54 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi return svc } +// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 +func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + return apiKeyService +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, NewAPIKeyService, + ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService, NewProxyService, diff --git a/config.yaml b/config.yaml index 54b591f3..ecd7dfc2 100644 --- a/config.yaml +++ b/config.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 87ff3148..87abffa0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置