feat: apikey支持5h/1d/7d速率控制
This commit is contained in:
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -16,10 +17,11 @@ import (
|
||||
|
||||
type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client, sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
@@ -37,7 +39,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
SetNillableLastUsedAt(key.LastUsedAt).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt)
|
||||
SetNillableExpiresAt(key.ExpiresAt).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -118,6 +123,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
apikey.FieldQuota,
|
||||
apikey.FieldQuotaUsed,
|
||||
apikey.FieldExpiresAt,
|
||||
apikey.FieldRateLimit5h,
|
||||
apikey.FieldRateLimit1d,
|
||||
apikey.FieldRateLimit7d,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
@@ -179,6 +187,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
SetStatus(key.Status).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d).
|
||||
SetUsage5h(key.Usage5h).
|
||||
SetUsage1d(key.Usage1d).
|
||||
SetUsage7d(key.Usage7d).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
@@ -193,6 +207,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
// Rate limit window start times
|
||||
if key.Window5hStart != nil {
|
||||
builder.SetWindow5hStart(*key.Window5hStart)
|
||||
} else {
|
||||
builder.ClearWindow5hStart()
|
||||
}
|
||||
if key.Window1dStart != nil {
|
||||
builder.SetWindow1dStart(*key.Window1dStart)
|
||||
} else {
|
||||
builder.ClearWindow1dStart()
|
||||
}
|
||||
if key.Window7dStart != nil {
|
||||
builder.SetWindow7dStart(*key.Window7dStart)
|
||||
} else {
|
||||
builder.ClearWindow7dStart()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -412,25 +443,88 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
|
||||
// window start times via COALESCE if not already set.
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = usage_5h + $1,
|
||||
usage_1d = usage_1d + $1,
|
||||
usage_7d = usage_7d + $1,
|
||||
window_5h_start = COALESCE(window_5h_start, NOW()),
|
||||
window_1d_start = COALESCE(window_1d_start, NOW()),
|
||||
window_7d_start = COALESCE(window_7d_start, NOW()),
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetRateLimitWindows resets expired rate limit windows atomically.
|
||||
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
|
||||
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
|
||||
FROM api_keys
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
data := &service.APIKeyRateLimitData{}
|
||||
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, rows.Err()
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
RateLimit5h: m.RateLimit5h,
|
||||
RateLimit1d: m.RateLimit1d,
|
||||
RateLimit7d: m.RateLimit7d,
|
||||
Usage5h: m.Usage5h,
|
||||
Usage1d: m.Usage1d,
|
||||
Usage7d: m.Usage7d,
|
||||
Window5hStart: m.Window5hStart,
|
||||
Window1dStart: m.Window1dStart,
|
||||
Window7dStart: m.Window7dStart,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
|
||||
@@ -14,10 +14,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingRateLimitKeyPrefix = "apikey:rate:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
@@ -49,6 +51,20 @@ const (
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
// billingRateLimitKey generates the Redis key for API key rate limit cache.
|
||||
func billingRateLimitKey(keyID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
|
||||
}
|
||||
|
||||
const (
|
||||
rateLimitFieldUsage5h = "usage_5h"
|
||||
rateLimitFieldUsage1d = "usage_1d"
|
||||
rateLimitFieldUsage7d = "usage_7d"
|
||||
rateLimitFieldWindow5h = "window_5h"
|
||||
rateLimitFieldWindow1d = "window_1d"
|
||||
rateLimitFieldWindow7d = "window_7d"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
@@ -73,6 +89,21 @@ var (
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
||||
updateRateLimitUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
@@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
|
||||
key := billingRateLimitKey(keyID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
data := &service.APIKeyRateLimitCacheData{}
|
||||
if v, ok := result[rateLimitFieldUsage5h]; ok {
|
||||
data.Usage5h, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage1d]; ok {
|
||||
data.Usage1d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage7d]; ok {
|
||||
data.Usage7d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow5h]; ok {
|
||||
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow1d]; ok {
|
||||
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow7d]; ok {
|
||||
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
key := billingRateLimitKey(keyID)
|
||||
fields := map[string]any{
|
||||
rateLimitFieldUsage5h: data.Usage5h,
|
||||
rateLimitFieldUsage1d: data.Usage1d,
|
||||
rateLimitFieldUsage7d: data.Usage7d,
|
||||
rateLimitFieldWindow5h: data.Window5h,
|
||||
rateLimitFieldWindow1d: data.Window1d,
|
||||
rateLimitFieldWindow7d: data.Window7d,
|
||||
}
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, rateLimitCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user