fix: 修复keys速率限制未自动重置额度的bug

This commit is contained in:
shaw
2026-03-07 09:59:40 +08:00
parent 2d8d3b7857
commit 7a353028e7
7 changed files with 326 additions and 17 deletions

View File

@@ -89,9 +89,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
RateLimit5h: k.RateLimit5h, RateLimit5h: k.RateLimit5h,
RateLimit1d: k.RateLimit1d, RateLimit1d: k.RateLimit1d,
RateLimit7d: k.RateLimit7d, RateLimit7d: k.RateLimit7d,
Usage5h: k.Usage5h, Usage5h: k.EffectiveUsage5h(),
Usage1d: k.Usage1d, Usage1d: k.EffectiveUsage1d(),
Usage7d: k.Usage7d, Usage7d: k.EffectiveUsage7d(),
Window5hStart: k.Window5hStart, Window5hStart: k.Window5hStart,
Window1dStart: k.Window1dStart, Window1dStart: k.Window1dStart,
Window7dStart: k.Window7dStart, Window7dStart: k.Window7dStart,

View File

@@ -971,7 +971,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
if err == nil && rateLimitData != nil { if err == nil && rateLimitData != nil {
var rateLimits []gin.H var rateLimits []gin.H
if apiKey.RateLimit5h > 0 { if apiKey.RateLimit5h > 0 {
used := rateLimitData.Usage5h used := rateLimitData.EffectiveUsage5h()
rateLimits = append(rateLimits, gin.H{ rateLimits = append(rateLimits, gin.H{
"window": "5h", "window": "5h",
"limit": apiKey.RateLimit5h, "limit": apiKey.RateLimit5h,
@@ -981,7 +981,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
}) })
} }
if apiKey.RateLimit1d > 0 { if apiKey.RateLimit1d > 0 {
used := rateLimitData.Usage1d used := rateLimitData.EffectiveUsage1d()
rateLimits = append(rateLimits, gin.H{ rateLimits = append(rateLimits, gin.H{
"window": "1d", "window": "1d",
"limit": apiKey.RateLimit1d, "limit": apiKey.RateLimit1d,
@@ -991,7 +991,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
}) })
} }
if apiKey.RateLimit7d > 0 { if apiKey.RateLimit7d > 0 {
used := rateLimitData.Usage7d used := rateLimitData.EffectiveUsage7d()
rateLimits = append(rateLimits, gin.H{ rateLimits = append(rateLimits, gin.H{
"window": "7d", "window": "7d",
"limit": apiKey.RateLimit7d, "limit": apiKey.RateLimit7d,

View File

@@ -470,12 +470,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, ` _, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET UPDATE api_keys SET
usage_5h = usage_5h + $1, usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = usage_1d + $1, usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = usage_7d + $1, usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = COALESCE(window_5h_start, NOW()), window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = COALESCE(window_1d_start, NOW()), window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
window_7d_start = COALESCE(window_7d_start, NOW()), window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW() updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`, WHERE id = $2 AND deleted_at IS NULL`,
cost, id) cost, id)

View File

@@ -14,6 +14,18 @@ const (
StatusAPIKeyExpired = "expired" StatusAPIKeyExpired = "expired"
) )
// Rate limit window durations
const (
RateLimitWindow5h = 5 * time.Hour
RateLimitWindow1d = 24 * time.Hour
RateLimitWindow7d = 7 * 24 * time.Hour
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
return windowStart != nil && time.Since(*windowStart) >= duration
}
type APIKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
@@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int {
return int(duration.Hours() / 24) return int(duration.Hours() / 24)
} }
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage5h() float64 {
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
return 0
}
return k.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage1d() float64 {
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
return 0
}
return k.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage7d() float64 {
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
return 0
}
return k.Usage7d
}
// APIKeyListFilters holds optional filtering parameters for listing API keys. // APIKeyListFilters holds optional filtering parameters for listing API keys.
type APIKeyListFilters struct { type APIKeyListFilters struct {
Search string Search string

View File

@@ -0,0 +1,245 @@
package service
import (
"testing"
"time"
)
func TestIsWindowExpired(t *testing.T) {
now := time.Now()
tests := []struct {
name string
start *time.Time
duration time.Duration
want bool
}{
{
name: "nil window start",
start: nil,
duration: RateLimitWindow5h,
want: false,
},
{
name: "active window (started 1h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
duration: RateLimitWindow5h,
want: false,
},
{
name: "expired window (started 6h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "exactly at boundary (started 5h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "active 1d window (started 12h ago)",
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
duration: RateLimitWindow1d,
want: false,
},
{
name: "expired 1d window (started 25h ago)",
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
duration: RateLimitWindow1d,
want: true,
},
{
name: "active 7d window (started 3d ago)",
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: false,
},
{
name: "expired 7d window (started 8d ago)",
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsWindowExpired(tt.start, tt.duration)
if got != tt.want {
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIKey_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
key APIKey
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "all windows expired",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return raw usage",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: nil,
},
want5h: 0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "zero usage with active windows",
key: APIKey{
Usage5h: 0,
Usage1d: 0,
Usage7d: 0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
data APIKeyRateLimitData
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
},
{
name: "all windows expired",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return raw usage",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func rateLimitTimePtr(t time.Time) *time.Time {
return &t
}

View File

@@ -86,6 +86,30 @@ type APIKeyRateLimitData struct {
Window7dStart *time.Time Window7dStart *time.Time
} }
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
return 0
}
return d.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
return 0
}
return d.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
return 0
}
return d.Usage7d
}
// APIKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
type APIKeyCache interface { type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)

View File

@@ -565,15 +565,15 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
needsReset := false needsReset := false
// Reset expired windows in-memory for check purposes // Reset expired windows in-memory for check purposes
if w5h != nil && time.Since(*w5h) >= 5*time.Hour { if IsWindowExpired(w5h, RateLimitWindow5h) {
usage5h = 0 usage5h = 0
needsReset = true needsReset = true
} }
if w1d != nil && time.Since(*w1d) >= 24*time.Hour { if IsWindowExpired(w1d, RateLimitWindow1d) {
usage1d = 0 usage1d = 0
needsReset = true needsReset = true
} }
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour { if IsWindowExpired(w7d, RateLimitWindow7d) {
usage7d = 0 usage7d = 0
needsReset = true needsReset = true
} }
@@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP
if loader, ok := s.apiKeyRateLimitLoader.(interface { if loader, ok := s.apiKeyRateLimitLoader.(interface {
ResetRateLimitWindows(ctx context.Context, id int64) error ResetRateLimitWindows(ctx context.Context, id int64) error
}); ok { }); ok {
_ = loader.ResetRateLimitWindows(resetCtx, keyID) if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
}
} }
} }
// Invalidate cache so next request loads fresh data // Invalidate cache so next request loads fresh data
if s.cache != nil { if s.cache != nil {
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID) if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
}
} }
}() }()
} }