perf(认证): 引入 API Key 认证缓存与轻量删除查询
增加 L1/L2 缓存、负缓存与单飞回源 使用 key+owner 轻量查询替代全量加载并清理旧接口 补充缓存失效与余额更新测试,修复随机抖动 lint 测试: make test
This commit is contained in:
@@ -57,13 +57,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
userService := service.NewUserService(userRepository)
|
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
|
||||||
userHandler := handler.NewUserHandler(userService)
|
|
||||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||||
groupRepository := repository.NewGroupRepository(client, db)
|
groupRepository := repository.NewGroupRepository(client, db)
|
||||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||||
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||||
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
||||||
@@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
accountRepository := repository.NewAccountRepository(client, db)
|
accountRepository := repository.NewAccountRepository(client, db)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
|
|||||||
@@ -44,11 +44,13 @@ require (
|
|||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/purego v0.8.4 // indirect
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/fatih/color v1.18.0 // indirect
|
github.com/fatih/color v1.18.0 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
|
|||||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
|
|||||||
@@ -36,25 +36,26 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Security SecurityConfig `mapstructure:"security"`
|
Security SecurityConfig `mapstructure:"security"`
|
||||||
Billing BillingConfig `mapstructure:"billing"`
|
Billing BillingConfig `mapstructure:"billing"`
|
||||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
Update UpdateConfig `mapstructure:"update"`
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
|
Update UpdateConfig `mapstructure:"update"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConfig 在线更新相关配置
|
// UpdateConfig 在线更新相关配置
|
||||||
@@ -361,6 +362,16 @@ type RateLimitConfig struct {
|
|||||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||||
|
type APIKeyAuthCacheConfig struct {
|
||||||
|
L1Size int `mapstructure:"l1_size"`
|
||||||
|
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
|
||||||
|
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
|
||||||
|
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
|
||||||
|
JitterPercent int `mapstructure:"jitter_percent"`
|
||||||
|
Singleflight bool `mapstructure:"singleflight"`
|
||||||
|
}
|
||||||
|
|
||||||
func NormalizeRunMode(value string) string {
|
func NormalizeRunMode(value string) string {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||||
switch normalized {
|
switch normalized {
|
||||||
@@ -655,6 +666,14 @@ func setDefaults() {
|
|||||||
// Timezone (default to Asia/Shanghai for Chinese users)
|
// Timezone (default to Asia/Shanghai for Chinese users)
|
||||||
viper.SetDefault("timezone", "Asia/Shanghai")
|
viper.SetDefault("timezone", "Asia/Shanghai")
|
||||||
|
|
||||||
|
// API Key auth cache
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
|
||||||
|
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
|
||||||
|
viper.SetDefault("api_key_auth_cache.singleflight", true)
|
||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
viper.SetDefault("gateway.log_upstream_error_body", false)
|
viper.SetDefault("gateway.log_upstream_error_body", false)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||||
apiKeyRateLimitDuration = 24 * time.Hour
|
apiKeyRateLimitDuration = 24 * time.Hour
|
||||||
|
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||||
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
|
|||||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func apiKeyAuthCacheKey(key string) string {
|
||||||
|
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||||
|
}
|
||||||
|
|
||||||
type apiKeyCache struct {
|
type apiKeyCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
}
|
}
|
||||||
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
|
|||||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entry service.APIKeyAuthCacheEntry
|
||||||
|
if err := json.Unmarshal(val, &entry); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
if entry == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||||
// 相比 GetByID,此方法性能更优,因为:
|
// 相比 GetByID,此方法性能更优,因为:
|
||||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
m, err := r.activeQuery().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.IDEQ(id)).
|
Where(apikey.IDEQ(id)).
|
||||||
Select(apikey.FieldUserID).
|
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if dbent.IsNotFound(err) {
|
if dbent.IsNotFound(err) {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
return m.UserID, nil
|
return m.Key, m.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
m, err := r.activeQuery().
|
||||||
|
Where(apikey.KeyEQ(key)).
|
||||||
|
Select(
|
||||||
|
apikey.FieldID,
|
||||||
|
apikey.FieldUserID,
|
||||||
|
apikey.FieldGroupID,
|
||||||
|
apikey.FieldStatus,
|
||||||
|
apikey.FieldIPWhitelist,
|
||||||
|
apikey.FieldIPBlacklist,
|
||||||
|
).
|
||||||
|
WithUser(func(q *dbent.UserQuery) {
|
||||||
|
q.Select(
|
||||||
|
user.FieldID,
|
||||||
|
user.FieldStatus,
|
||||||
|
user.FieldRole,
|
||||||
|
user.FieldBalance,
|
||||||
|
user.FieldConcurrency,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
WithGroup(func(q *dbent.GroupQuery) {
|
||||||
|
q.Select(
|
||||||
|
group.FieldID,
|
||||||
|
group.FieldName,
|
||||||
|
group.FieldPlatform,
|
||||||
|
group.FieldStatus,
|
||||||
|
group.FieldSubscriptionType,
|
||||||
|
group.FieldRateMultiplier,
|
||||||
|
group.FieldDailyLimitUsd,
|
||||||
|
group.FieldWeeklyLimitUsd,
|
||||||
|
group.FieldMonthlyLimitUsd,
|
||||||
|
group.FieldImagePrice1k,
|
||||||
|
group.FieldImagePrice2k,
|
||||||
|
group.FieldImagePrice4k,
|
||||||
|
group.FieldClaudeCodeOnly,
|
||||||
|
group.FieldFallbackGroupID,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return apiKeyEntityToService(m), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||||
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
|||||||
return int64(count), err
|
return int64(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.UserIDEQ(userID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.GroupIDEQ(groupID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -389,7 +389,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
RunMode: config.RunModeStandard,
|
RunMode: config.RunModeStandard,
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := service.NewUserService(userRepo)
|
userService := service.NewUserService(userRepo, nil)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||||
|
|
||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
@@ -565,6 +565,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type stubGroupRepo struct{}
|
type stubGroupRepo struct{}
|
||||||
|
|
||||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||||
@@ -737,12 +749,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
key, ok := r.byID[id]
|
key, ok := r.byID[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return key.UserID, nil
|
return key.Key, key.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -754,6 +766,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return errors.New("nil key")
|
return errors.New("nil key")
|
||||||
@@ -868,6 +884,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUsageLogRepo struct {
|
type stubUsageLogRepo struct {
|
||||||
userLogs map[int64][]service.UsageLog
|
userLogs map[int64][]service.UsageLog
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
|||||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
if f.getByKey == nil {
|
if f.getByKey == nil {
|
||||||
@@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
|
|||||||
}
|
}
|
||||||
return f.getByKey(ctx, key)
|
return f.getByKey(ctx, key)
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return f.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
|
|||||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type googleErrorResponse struct {
|
type googleErrorResponse struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
|
|||||||
@@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct {
|
type stubUserSubscriptionRepo struct {
|
||||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||||
|
|||||||
@@ -244,14 +244,15 @@ type ProxyExitInfoProber interface {
|
|||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
@@ -264,16 +265,18 @@ func NewAdminService(
|
|||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
}
|
}
|
||||||
|
|
||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
|
oldStatus := user.Status
|
||||||
|
oldRole := user.Role
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
user.Email = input.Email
|
user.Email = input.Email
|
||||||
@@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
@@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
|||||||
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
balanceDiff := user.Balance - oldBalance
|
||||||
|
if s.authCacheInvalidator != nil && balanceDiff != 0 {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
balanceDiff := user.Balance - oldBalance
|
|
||||||
if balanceDiff != 0 {
|
if balanceDiff != 0 {
|
||||||
code, err := GenerateRedeemCode()
|
code, err := GenerateRedeemCode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||||
|
var groupKeys []string
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
|
||||||
|
if err == nil {
|
||||||
|
groupKeys = keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
for _, key := range groupKeys {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type balanceUserRepoStub struct {
|
||||||
|
*userRepoStub
|
||||||
|
updateErr error
|
||||||
|
updated []*User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||||
|
if s.updateErr != nil {
|
||||||
|
return s.updateErr
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *user
|
||||||
|
s.updated = append(s.updated, &clone)
|
||||||
|
if s.userRepoStub != nil {
|
||||||
|
s.userRepoStub.user = &clone
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type balanceRedeemRepoStub struct {
|
||||||
|
*redeemRepoStub
|
||||||
|
created []*RedeemCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||||
|
if code == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *code
|
||||||
|
s.created = append(s.created, &clone)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheInvalidatorStub struct {
|
||||||
|
userIDs []int64
|
||||||
|
groupIDs []int64
|
||||||
|
keys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
s.keys = append(s.keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
s.userIDs = append(s.userIDs, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
s.groupIDs = append(s.groupIDs, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||||
|
require.Len(t, redeemRepo.created, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, invalidator.userIDs)
|
||||||
|
require.Empty(t, redeemRepo.created)
|
||||||
|
}
|
||||||
46
backend/internal/service/api_key_auth_cache.go
Normal file
46
backend/internal/service/api_key_auth_cache.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||||
|
type APIKeyAuthSnapshot struct {
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
|
User APIKeyAuthUserSnapshot `json:"user"`
|
||||||
|
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthUserSnapshot 用户快照
|
||||||
|
type APIKeyAuthUserSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Balance float64 `json:"balance"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
SubscriptionType string `json:"subscription_type"`
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||||
|
type APIKeyAuthCacheEntry struct {
|
||||||
|
NotFound bool `json:"not_found"`
|
||||||
|
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||||
|
}
|
||||||
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyAuthCacheConfig struct {
|
||||||
|
l1Size int
|
||||||
|
l1TTL time.Duration
|
||||||
|
l2TTL time.Duration
|
||||||
|
negativeTTL time.Duration
|
||||||
|
jitterPercent int
|
||||||
|
singleflight bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
jitterRandMu sync.Mutex
|
||||||
|
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||||
|
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return apiKeyAuthCacheConfig{}
|
||||||
|
}
|
||||||
|
auth := cfg.APIKeyAuth
|
||||||
|
return apiKeyAuthCacheConfig{
|
||||||
|
l1Size: auth.L1Size,
|
||||||
|
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||||
|
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||||
|
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||||
|
jitterPercent: auth.JitterPercent,
|
||||||
|
singleflight: auth.Singleflight,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||||
|
return c.l1Size > 0 && c.l1TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||||
|
return c.l2TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||||
|
return c.negativeTTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||||
|
if ttl <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
if c.jitterPercent <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
percent := c.jitterPercent
|
||||||
|
if percent > 100 {
|
||||||
|
percent = 100
|
||||||
|
}
|
||||||
|
delta := float64(percent) / 100
|
||||||
|
jitterRandMu.Lock()
|
||||||
|
randVal := jitterRand.Float64()
|
||||||
|
jitterRandMu.Unlock()
|
||||||
|
factor := 1 - delta + randVal*(2*delta)
|
||||||
|
if factor <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
return time.Duration(float64(ttl) * factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||||
|
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||||
|
if !s.authCfg.l1Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||||
|
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||||
|
MaxCost: int64(s.authCfg.l1Size),
|
||||||
|
BufferItems: 64,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authCacheL1 = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) authCacheKey(key string) string {
|
||||||
|
sum := sha256.Sum256([]byte(key))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||||
|
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||||
|
if s.authCacheL1 == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ttl := s.authCfg.l1TTL
|
||||||
|
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||||
|
ttl = s.authCfg.negativeTTL
|
||||||
|
}
|
||||||
|
ttl = s.authCfg.jitterTTL(ttl)
|
||||||
|
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||||
|
if entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
s.authCacheL1.Del(cacheKey)
|
||||||
|
}
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||||
|
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||||
|
if s.authCfg.negativeEnabled() {
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||||
|
}
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
apiKey.Key = key
|
||||||
|
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||||
|
if entry == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if entry.NotFound {
|
||||||
|
return nil, true, ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
if entry.Snapshot == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||||
|
if apiKey == nil || apiKey.User == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snapshot := &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: apiKey.UserID,
|
||||||
|
GroupID: apiKey.GroupID,
|
||||||
|
Status: apiKey.Status,
|
||||||
|
IPWhitelist: apiKey.IPWhitelist,
|
||||||
|
IPBlacklist: apiKey.IPBlacklist,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: apiKey.User.ID,
|
||||||
|
Status: apiKey.User.Status,
|
||||||
|
Role: apiKey.User.Role,
|
||||||
|
Balance: apiKey.User.Balance,
|
||||||
|
Concurrency: apiKey.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: apiKey.Group.ID,
|
||||||
|
Name: apiKey.Group.Name,
|
||||||
|
Platform: apiKey.Group.Platform,
|
||||||
|
Status: apiKey.Group.Status,
|
||||||
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey := &APIKey{
|
||||||
|
ID: snapshot.APIKeyID,
|
||||||
|
UserID: snapshot.UserID,
|
||||||
|
GroupID: snapshot.GroupID,
|
||||||
|
Key: key,
|
||||||
|
Status: snapshot.Status,
|
||||||
|
IPWhitelist: snapshot.IPWhitelist,
|
||||||
|
IPBlacklist: snapshot.IPBlacklist,
|
||||||
|
User: &User{
|
||||||
|
ID: snapshot.User.ID,
|
||||||
|
Status: snapshot.User.Status,
|
||||||
|
Role: snapshot.User.Role,
|
||||||
|
Balance: snapshot.User.Balance,
|
||||||
|
Concurrency: snapshot.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if snapshot.Group != nil {
|
||||||
|
apiKey.Group = &Group{
|
||||||
|
ID: snapshot.Group.ID,
|
||||||
|
Name: snapshot.Group.Name,
|
||||||
|
Platform: snapshot.Group.Platform,
|
||||||
|
Status: snapshot.Group.Status,
|
||||||
|
Hydrated: true,
|
||||||
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return apiKey
|
||||||
|
}
|
||||||
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheKey := s.authCacheKey(key)
|
||||||
|
s.deleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
if groupID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,9 +33,11 @@ const (
|
|||||||
type APIKeyRepository interface {
|
type APIKeyRepository interface {
|
||||||
Create(ctx context.Context, key *APIKey) error
|
Create(ctx context.Context, key *APIKey) error
|
||||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||||
|
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||||
Update(ctx context.Context, key *APIKey) error
|
Update(ctx context.Context, key *APIKey) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
|
|||||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
|
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
@@ -55,6 +61,17 @@ type APIKeyCache interface {
|
|||||||
|
|
||||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||||
|
|
||||||
|
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||||
|
DeleteAuthCache(ctx context.Context, key string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||||
|
type APIKeyAuthCacheInvalidator interface {
|
||||||
|
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||||
|
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||||
|
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAPIKeyRequest 创建API Key请求
|
// CreateAPIKeyRequest 创建API Key请求
|
||||||
@@ -83,6 +100,9 @@ type APIKeyService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache APIKeyCache
|
cache APIKeyCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
authCacheL1 *ristretto.Cache
|
||||||
|
authCfg apiKeyAuthCacheConfig
|
||||||
|
authGroup singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIKeyService 创建API Key服务实例
|
// NewAPIKeyService 创建API Key服务实例
|
||||||
@@ -94,7 +114,7 @@ func NewAPIKeyService(
|
|||||||
cache APIKeyCache,
|
cache APIKeyCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *APIKeyService {
|
) *APIKeyService {
|
||||||
return &APIKeyService{
|
svc := &APIKeyService{
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -102,6 +122,8 @@ func NewAPIKeyService(
|
|||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
|
svc.initAuthCache(cfg)
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateKey 生成随机API Key
|
// GenerateKey 生成随机API Key
|
||||||
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
return nil, fmt.Errorf("create api key: %w", err)
|
return nil, fmt.Errorf("create api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
|||||||
|
|
||||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
// 尝试从Redis缓存获取
|
cacheKey := s.authCacheKey(key)
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
|
||||||
|
|
||||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.authCfg.singleflight {
|
||||||
|
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||||
|
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get api key: %w", err)
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
|
apiKey.Key = key
|
||||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
|
||||||
if s.cache != nil {
|
|
||||||
// 这里可以序列化并缓存API Key
|
|
||||||
_ = cacheKey // 使用变量避免未使用错误
|
|
||||||
}
|
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
return nil, fmt.Errorf("update api key: %w", err)
|
return nil, fmt.Errorf("update api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 删除API Key
|
// Delete 删除API Key
|
||||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
|
||||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
|
||||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get api key: %w", err)
|
return fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
|||||||
return ErrInsufficientPerms
|
return ErrInsufficientPerms
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清除Redis缓存(使用 ownerID 而非 apiKey.UserID)
|
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||||
if s.cache != nil {
|
if s.cache != nil {
|
||||||
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||||
}
|
}
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete api key: %w", err)
|
return fmt.Errorf("delete api key: %w", err)
|
||||||
|
|||||||
417
backend/internal/service/api_key_service_cache_test.go
Normal file
417
backend/internal/service/api_key_service_cache_test.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authRepoStub struct {
|
||||||
|
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
if s.getByKeyForAuth == nil {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
return s.getByKeyForAuth(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
if s.listKeysByUserID == nil {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
if s.listKeysByGroupID == nil {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheStub struct {
|
||||||
|
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
setAuthKeys []string
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
if s.getAuthCache == nil {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
return s.getAuthCache(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
groupID := int64(9)
|
||||||
|
cacheEntry := &APIKeyAuthCacheEntry{
|
||||||
|
Snapshot: &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: 1,
|
||||||
|
UserID: 2,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
},
|
||||||
|
Group: &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "g",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return cacheEntry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(1), apiKey.ID)
|
||||||
|
require.Equal(t, int64(2), apiKey.User.ID)
|
||||||
|
require.Equal(t, groupID, apiKey.Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return &APIKey{
|
||||||
|
ID: 5,
|
||||||
|
UserID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 12,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(5), apiKey.ID)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 21,
|
||||||
|
UserID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 5,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L1Size: 1000,
|
||||||
|
L1TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
require.NotNil(t, svc.authCacheL1)
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
svc.authCacheL1.Wait()
|
||||||
|
cacheKey := svc.authCacheKey("k-l1")
|
||||||
|
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, ErrAPIKeyNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 11,
|
||||||
|
UserID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 1,
|
||||||
|
Concurrency: 1,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
Singleflight: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
errs := make([]error, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
errs[idx] = err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for _, err := range errs {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
@@ -20,13 +20,12 @@ import (
|
|||||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
|
||||||
// - deleteErr: 模拟 Delete 返回的错误
|
// - deleteErr: 模拟 Delete 返回的错误
|
||||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||||
type apiKeyRepoStub struct {
|
type apiKeyRepoStub struct {
|
||||||
ownerID int64 // GetOwnerID 的返回值
|
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||||
ownerErr error // GetOwnerID 的错误返回值
|
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||||
deleteErr error // Delete 的错误返回值
|
deleteErr error // Delete 的错误返回值
|
||||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||||
}
|
}
|
||||||
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
if s.getByIDErr != nil {
|
||||||
|
return nil, s.getByIDErr
|
||||||
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
clone := *s.apiKey
|
||||||
|
return &clone, nil
|
||||||
|
}
|
||||||
panic("unexpected GetByID call")
|
panic("unexpected GetByID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 返回预设的所有者 ID 或错误。
|
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
|
if s.getByIDErr != nil {
|
||||||
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
return "", 0, s.getByIDErr
|
||||||
return s.ownerID, s.ownerErr
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||||
|
}
|
||||||
|
return "", 0, ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
panic("unexpected GetByKey call")
|
panic("unexpected GetByKey call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
panic("unexpected Update call")
|
panic("unexpected Update call")
|
||||||
}
|
}
|
||||||
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
panic("unexpected CountByGroupID call")
|
panic("unexpected CountByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||||
type apiKeyCacheStub struct {
|
type apiKeyCacheStub struct {
|
||||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||||
|
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||||
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 1
|
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||||
// - 调用者 userID 为 2(不匹配)
|
// - 调用者 userID 为 2(不匹配)
|
||||||
// - 返回 ErrInsufficientPerms 错误
|
// - 返回 ErrInsufficientPerms 错误
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 1}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 7
|
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||||
// - 调用者 userID 为 7(匹配)
|
// - 调用者 userID 为 7(匹配)
|
||||||
// - Delete 成功执行
|
// - Delete 成功执行
|
||||||
// - 缓存被正确清除(使用 ownerID)
|
// - 缓存被正确清除(使用 ownerID)
|
||||||
// - 返回 nil 错误
|
// - 返回 nil 错误
|
||||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 7}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
require.Empty(t, repo.deletedIDs)
|
require.Empty(t, repo.deletedIDs)
|
||||||
require.Empty(t, cache.invalidated)
|
require.Empty(t, cache.invalidated)
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回正确的所有者 ID
|
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||||
// - 所有权验证通过
|
// - 所有权验证通过
|
||||||
// - 缓存被清除(在删除之前)
|
// - 缓存被清除(在删除之前)
|
||||||
// - Delete 被调用但返回错误
|
// - Delete 被调用但返回错误
|
||||||
// - 返回包含 "delete api key" 的错误信息
|
// - 返回包含 "delete api key" 的错误信息
|
||||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{
|
repo := &apiKeyRepoStub{
|
||||||
ownerID: 3,
|
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||||
deleteErr: errors.New("delete failed"),
|
deleteErr: errors.New("delete failed"),
|
||||||
}
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
|||||||
require.ErrorContains(t, err, "delete api key")
|
require.ErrorContains(t, err, "delete api key")
|
||||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {
|
|||||||
|
|
||||||
// GroupService 分组管理服务
|
// GroupService 分组管理服务
|
||||||
type GroupService struct {
|
type GroupService struct {
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupService 创建分组服务实例
|
// NewGroupService 创建分组服务实例
|
||||||
func NewGroupService(groupRepo GroupRepository) *GroupService {
|
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
|
||||||
return &GroupService{
|
return &GroupService{
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, fmt.Errorf("update group: %w", err)
|
return nil, fmt.Errorf("update group: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
@@ -170,6 +175,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
|||||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete group: %w", err)
|
return fmt.Errorf("delete group: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,13 +55,15 @@ type ChangePasswordRequest struct {
|
|||||||
|
|
||||||
// UserService 用户服务
|
// UserService 用户服务
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserService 创建用户服务实例
|
// NewUserService 创建用户服务实例
|
||||||
func NewUserService(userRepo UserRepository) *UserService {
|
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get user: %w", err)
|
return nil, fmt.Errorf("get user: %w", err)
|
||||||
}
|
}
|
||||||
|
oldConcurrency := user.Concurrency
|
||||||
|
|
||||||
// 更新字段
|
// 更新字段
|
||||||
if req.Email != nil {
|
if req.Email != nil {
|
||||||
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, fmt.Errorf("update user: %w", err)
|
return nil, fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
|||||||
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
||||||
return fmt.Errorf("update balance: %w", err)
|
return fmt.Errorf("update balance: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
|
|||||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
||||||
return fmt.Errorf("update concurrency: %w", err)
|
return fmt.Errorf("update concurrency: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,6 +204,9 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return fmt.Errorf("update user: %w", err)
|
return fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -201,5 +216,8 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
|||||||
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
||||||
return fmt.Errorf("delete user: %w", err)
|
return fmt.Errorf("delete user: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
|
||||||
|
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
|
||||||
|
return apiKeyService
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
NewAuthService,
|
NewAuthService,
|
||||||
NewUserService,
|
NewUserService,
|
||||||
NewAPIKeyService,
|
NewAPIKeyService,
|
||||||
|
ProvideAPIKeyAuthCacheInvalidator,
|
||||||
NewGroupService,
|
NewGroupService,
|
||||||
NewAccountService,
|
NewAccountService,
|
||||||
NewProxyService,
|
NewProxyService,
|
||||||
|
|||||||
24
config.yaml
24
config.yaml
@@ -170,6 +170,30 @@ gateway:
|
|||||||
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
||||||
failover_on_400: false
|
failover_on_400: false
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Key Auth Cache Configuration
|
||||||
|
# API Key 认证缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
api_key_auth_cache:
|
||||||
|
# L1 cache size (entries), in-process LRU/TTL cache
|
||||||
|
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
|
||||||
|
l1_size: 65535
|
||||||
|
# L1 cache TTL (seconds)
|
||||||
|
# L1 缓存 TTL(秒)
|
||||||
|
l1_ttl_seconds: 15
|
||||||
|
# L2 cache TTL (seconds), stored in Redis
|
||||||
|
# L2 缓存 TTL(秒),Redis 中存储
|
||||||
|
l2_ttl_seconds: 300
|
||||||
|
# Negative cache TTL (seconds)
|
||||||
|
# 负缓存 TTL(秒)
|
||||||
|
negative_ttl_seconds: 30
|
||||||
|
# TTL jitter percent (0-100)
|
||||||
|
# TTL 抖动百分比(0-100)
|
||||||
|
jitter_percent: 10
|
||||||
|
# Enable singleflight for cache misses
|
||||||
|
# 缓存未命中时启用 singleflight 合并回源
|
||||||
|
singleflight: true
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Concurrency Wait Configuration
|
# Concurrency Wait Configuration
|
||||||
# 并发等待配置
|
# 并发等待配置
|
||||||
|
|||||||
@@ -170,6 +170,30 @@ gateway:
|
|||||||
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
||||||
failover_on_400: false
|
failover_on_400: false
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Key Auth Cache Configuration
|
||||||
|
# API Key 认证缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
api_key_auth_cache:
|
||||||
|
# L1 cache size (entries), in-process LRU/TTL cache
|
||||||
|
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
|
||||||
|
l1_size: 65535
|
||||||
|
# L1 cache TTL (seconds)
|
||||||
|
# L1 缓存 TTL(秒)
|
||||||
|
l1_ttl_seconds: 15
|
||||||
|
# L2 cache TTL (seconds), stored in Redis
|
||||||
|
# L2 缓存 TTL(秒),Redis 中存储
|
||||||
|
l2_ttl_seconds: 300
|
||||||
|
# Negative cache TTL (seconds)
|
||||||
|
# 负缓存 TTL(秒)
|
||||||
|
negative_ttl_seconds: 30
|
||||||
|
# TTL jitter percent (0-100)
|
||||||
|
# TTL 抖动百分比(0-100)
|
||||||
|
jitter_percent: 10
|
||||||
|
# Enable singleflight for cache misses
|
||||||
|
# 缓存未命中时启用 singleflight 合并回源
|
||||||
|
singleflight: true
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Concurrency Wait Configuration
|
# Concurrency Wait Configuration
|
||||||
# 并发等待配置
|
# 并发等待配置
|
||||||
|
|||||||
Reference in New Issue
Block a user