Files
sub2api/backend/internal/service/api_key_auth_cache_impl.go
james-6-23 dc5d42addc feat(rpm): RPM 限流模块优化
P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
2026-04-23 16:34:37 +08:00

343 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
l2TTL time.Duration
negativeTTL time.Duration
jitterPercent int
singleflight bool
}
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
}
auth := cfg.APIKeyAuth
return apiKeyAuthCacheConfig{
l1Size: auth.L1Size,
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
jitterPercent: auth.JitterPercent,
singleflight: auth.Singleflight,
}
}
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
return c.l1Size > 0 && c.l1TTL > 0
}
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
return c.l2TTL > 0
}
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
}
if c.jitterPercent <= 0 {
return ttl
}
percent := c.jitterPercent
if percent > 100 {
percent = 100
}
delta := float64(percent) / 100
randVal := rand.Float64()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
if !s.authCfg.l1Enabled() {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(s.authCfg.l1Size) * 10,
MaxCost: int64(s.authCfg.l1Size),
BufferItems: 64,
})
if err != nil {
return
}
s.authCacheL1 = cache
}
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
// This should be called after the service is fully initialized.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
if s.cache == nil || s.authCacheL1 == nil {
return
}
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
}
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
}
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
if s.authCacheL1 != nil {
if val, ok := s.authCacheL1.Get(cacheKey); ok {
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
return entry, true
}
}
}
if s.cache == nil || !s.authCfg.l2Enabled() {
return nil, false
}
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
if err != nil {
return nil, false
}
s.setAuthCacheL1(cacheKey, entry)
return entry, true
}
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
if s.authCacheL1 == nil || entry == nil {
return
}
ttl := s.authCfg.l1TTL
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
ttl = s.authCfg.negativeTTL
}
ttl = s.authCfg.jitterTTL(ttl)
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
if entry == nil {
return
}
s.setAuthCacheL1(cacheKey, entry)
if s.cache == nil || !s.authCfg.l2Enabled() {
return
}
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
if s.authCacheL1 != nil {
s.authCacheL1.Del(cacheKey)
}
if s.cache == nil {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
if errors.Is(err, ErrAPIKeyNotFound) {
entry := &APIKeyAuthCacheEntry{NotFound: true}
if s.authCfg.negativeEnabled() {
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
}
return entry, nil
}
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
return entry, nil
}
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
if entry == nil {
return nil, false, nil
}
if entry.NotFound {
return nil, true, ErrAPIKeyNotFound
}
if entry.Snapshot == nil {
return nil, false, nil
}
if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
Email: apiKey.User.Email,
Username: apiKey.User.Username,
BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
RPMLimit: apiKey.User.RPMLimit,
},
}
// 填充 (user, group) RPM override —— snapshot 构建时查一次 DB后续请求零 DB 往返。
if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
if err == nil && override != nil {
snapshot.User.UserGroupRPMOverride = override
}
// 查询失败或无 override 时留 nilcheckRPM 会回退到 DB 查询
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
RPMLimit: apiKey.Group.RPMLimit,
}
}
return snapshot
}
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
if snapshot == nil {
return nil
}
apiKey := &APIKey{
ID: snapshot.APIKeyID,
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
Email: snapshot.User.Email,
Username: snapshot.User.Username,
BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
RPMLimit: snapshot.User.RPMLimit,
UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
},
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
RPMLimit: snapshot.Group.RPMLimit,
}
}
s.compileAPIKeyIPRules(apiKey)
return apiKey
}