package service import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "math/rand" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/dgraph-io/ristretto" ) type apiKeyAuthCacheConfig struct { l1Size int l1TTL time.Duration l2TTL time.Duration negativeTTL time.Duration jitterPercent int singleflight bool } var ( jitterRandMu sync.Mutex // 认证缓存抖动使用独立随机源,避免全局 Seed jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) ) func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { if cfg == nil { return apiKeyAuthCacheConfig{} } auth := cfg.APIKeyAuth return apiKeyAuthCacheConfig{ l1Size: auth.L1Size, l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second, l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second, negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second, jitterPercent: auth.JitterPercent, singleflight: auth.Singleflight, } } func (c apiKeyAuthCacheConfig) l1Enabled() bool { return c.l1Size > 0 && c.l1TTL > 0 } func (c apiKeyAuthCacheConfig) l2Enabled() bool { return c.l2TTL > 0 } func (c apiKeyAuthCacheConfig) negativeEnabled() bool { return c.negativeTTL > 0 } func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { if ttl <= 0 { return ttl } if c.jitterPercent <= 0 { return ttl } percent := c.jitterPercent if percent > 100 { percent = 100 } delta := float64(percent) / 100 jitterRandMu.Lock() randVal := jitterRand.Float64() jitterRandMu.Unlock() factor := 1 - delta + randVal*(2*delta) if factor <= 0 { return ttl } return time.Duration(float64(ttl) * factor) } func (s *APIKeyService) initAuthCache(cfg *config.Config) { s.authCfg = newAPIKeyAuthCacheConfig(cfg) if !s.authCfg.l1Enabled() { return } cache, err := ristretto.NewCache(&ristretto.Config{ NumCounters: int64(s.authCfg.l1Size) * 10, MaxCost: int64(s.authCfg.l1Size), BufferItems: 64, }) if err != nil { return } s.authCacheL1 = cache } // StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation. // This should be called after the service is fully initialized. func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) { if s.cache == nil || s.authCacheL1 == nil { return } if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) { s.authCacheL1.Del(cacheKey) }); err != nil { // Log but don't fail - L1 cache will still work, just without cross-instance invalidation println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error()) } } func (s *APIKeyService) authCacheKey(key string) string { sum := sha256.Sum256([]byte(key)) return hex.EncodeToString(sum[:]) } func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) { if s.authCacheL1 != nil { if val, ok := s.authCacheL1.Get(cacheKey); ok { if entry, ok := val.(*APIKeyAuthCacheEntry); ok { return entry, true } } } if s.cache == nil || !s.authCfg.l2Enabled() { return nil, false } entry, err := s.cache.GetAuthCache(ctx, cacheKey) if err != nil { return nil, false } s.setAuthCacheL1(cacheKey, entry) return entry, true } func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) { if s.authCacheL1 == nil || entry == nil { return } ttl := s.authCfg.l1TTL if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl { ttl = s.authCfg.negativeTTL } ttl = s.authCfg.jitterTTL(ttl) _ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl) } func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) { if entry == nil { return } s.setAuthCacheL1(cacheKey, entry) if s.cache == nil || !s.authCfg.l2Enabled() { return } _ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl)) } func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { if s.authCacheL1 != nil { s.authCacheL1.Del(cacheKey) } if s.cache == nil { return } _ = s.cache.DeleteAuthCache(ctx, cacheKey) // Publish invalidation message to other instances _ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey) } func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) if err != nil { if errors.Is(err, ErrAPIKeyNotFound) { entry := &APIKeyAuthCacheEntry{NotFound: true} if s.authCfg.negativeEnabled() { s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL) } return entry, nil } return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key snapshot := s.snapshotFromAPIKey(apiKey) if snapshot == nil { return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) } entry := &APIKeyAuthCacheEntry{Snapshot: snapshot} s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL) return entry, nil } func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) { if entry == nil { return nil, false, nil } if entry.NotFound { return nil, true, ErrAPIKeyNotFound } if entry.Snapshot == nil { return nil, false, nil } return s.snapshotToAPIKey(key, entry.Snapshot), true, nil } func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { if apiKey == nil || apiKey.User == nil { return nil } snapshot := &APIKeyAuthSnapshot{ APIKeyID: apiKey.ID, UserID: apiKey.UserID, GroupID: apiKey.GroupID, Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, Quota: apiKey.Quota, QuotaUsed: apiKey.QuotaUsed, ExpiresAt: apiKey.ExpiresAt, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, Role: apiKey.User.Role, Balance: apiKey.User.Balance, Concurrency: apiKey.User.Concurrency, }, } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ ID: apiKey.Group.ID, Name: apiKey.Group.Name, Platform: apiKey.Group.Platform, Status: apiKey.Group.Status, SubscriptionType: apiKey.Group.SubscriptionType, RateMultiplier: apiKey.Group.RateMultiplier, DailyLimitUSD: apiKey.Group.DailyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, ModelRouting: apiKey.Group.ModelRouting, ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, MCPXMLInject: apiKey.Group.MCPXMLInject, SupportedModelScopes: apiKey.Group.SupportedModelScopes, } } return snapshot } func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey { if snapshot == nil { return nil } apiKey := &APIKey{ ID: snapshot.APIKeyID, UserID: snapshot.UserID, GroupID: snapshot.GroupID, Key: key, Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, Quota: snapshot.Quota, QuotaUsed: snapshot.QuotaUsed, ExpiresAt: snapshot.ExpiresAt, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, Role: snapshot.User.Role, Balance: snapshot.User.Balance, Concurrency: snapshot.User.Concurrency, }, } if snapshot.Group != nil { apiKey.Group = &Group{ ID: snapshot.Group.ID, Name: snapshot.Group.Name, Platform: snapshot.Group.Platform, Status: snapshot.Group.Status, Hydrated: true, SubscriptionType: snapshot.Group.SubscriptionType, RateMultiplier: snapshot.Group.RateMultiplier, DailyLimitUSD: snapshot.Group.DailyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, ModelRouting: snapshot.Group.ModelRouting, ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, MCPXMLInject: snapshot.Group.MCPXMLInject, SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } return apiKey }