From 7a353028e7ce373f3afdd5e8d77779cf3afcc779 Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 7 Mar 2026 09:59:40 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dkeys=E9=80=9F=E7=8E=87?= =?UTF-8?q?=E9=99=90=E5=88=B6=E6=9C=AA=E8=87=AA=E5=8A=A8=E9=87=8D=E7=BD=AE?= =?UTF-8?q?=E9=A2=9D=E5=BA=A6=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/handler/dto/mappers.go | 6 +- backend/internal/handler/gateway_handler.go | 6 +- backend/internal/repository/api_key_repo.go | 12 +- backend/internal/service/api_key.go | 36 +++ .../service/api_key_rate_limit_test.go | 245 ++++++++++++++++++ backend/internal/service/api_key_service.go | 24 ++ .../internal/service/billing_cache_service.go | 14 +- 7 files changed, 326 insertions(+), 17 deletions(-) create mode 100644 backend/internal/service/api_key_rate_limit_test.go diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 03b122f3..31a02cca 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -89,9 +89,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey { RateLimit5h: k.RateLimit5h, RateLimit1d: k.RateLimit1d, RateLimit7d: k.RateLimit7d, - Usage5h: k.Usage5h, - Usage1d: k.Usage1d, - Usage7d: k.Usage7d, + Usage5h: k.EffectiveUsage5h(), + Usage1d: k.EffectiveUsage1d(), + Usage7d: k.EffectiveUsage7d(), Window5hStart: k.Window5hStart, Window1dStart: k.Window1dStart, Window7dStart: k.Window7dStart, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 1c0ef8e6..148d83e9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -971,7 +971,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, if err == nil && rateLimitData != nil { var rateLimits []gin.H if apiKey.RateLimit5h > 0 { - used := rateLimitData.Usage5h + used := rateLimitData.EffectiveUsage5h() rateLimits = append(rateLimits, gin.H{ "window": "5h", "limit": apiKey.RateLimit5h, @@ -981,7 +981,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, }) } if apiKey.RateLimit1d > 0 { - used := rateLimitData.Usage1d + used := rateLimitData.EffectiveUsage1d() rateLimits = append(rateLimits, gin.H{ "window": "1d", "limit": apiKey.RateLimit1d, @@ -991,7 +991,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, }) } if apiKey.RateLimit7d > 0 { - used := rateLimitData.Usage7d + used := rateLimitData.EffectiveUsage7d() rateLimits = append(rateLimits, gin.H{ "window": "7d", "limit": apiKey.RateLimit7d, diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index c761e8c9..d9732f68 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -470,12 +470,12 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { _, err := r.sql.ExecContext(ctx, ` UPDATE api_keys SET - usage_5h = usage_5h + $1, - usage_1d = usage_1d + $1, - usage_7d = usage_7d + $1, - window_5h_start = COALESCE(window_5h_start, NOW()), - window_1d_start = COALESCE(window_1d_start, NOW()), - window_7d_start = COALESCE(window_7d_start, NOW()), + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL`, cost, id) diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 4c565495..eb9f2b15 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -14,6 +14,18 @@ const ( StatusAPIKeyExpired = "expired" ) +// Rate limit window durations +const ( + RateLimitWindow5h = 5 * time.Hour + RateLimitWindow1d = 24 * time.Hour + RateLimitWindow7d = 7 * 24 * time.Hour +) + +// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration. +func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool { + return windowStart != nil && time.Since(*windowStart) >= duration +} + type APIKey struct { ID int64 UserID int64 @@ -98,6 +110,30 @@ func (k *APIKey) GetDaysUntilExpiry() int { return int(duration.Hours() / 24) } +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage5h() float64 { + if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) { + return 0 + } + return k.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage1d() float64 { + if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) { + return 0 + } + return k.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (k *APIKey) EffectiveUsage7d() float64 { + if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) { + return 0 + } + return k.Usage7d +} + // APIKeyListFilters holds optional filtering parameters for listing API keys. type APIKeyListFilters struct { Search string diff --git a/backend/internal/service/api_key_rate_limit_test.go b/backend/internal/service/api_key_rate_limit_test.go new file mode 100644 index 00000000..cf7e7983 --- /dev/null +++ b/backend/internal/service/api_key_rate_limit_test.go @@ -0,0 +1,245 @@ +package service + +import ( + "testing" + "time" +) + +func TestIsWindowExpired(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + start *time.Time + duration time.Duration + want bool + }{ + { + name: "nil window start", + start: nil, + duration: RateLimitWindow5h, + want: false, + }, + { + name: "active window (started 1h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-1 * time.Hour)), + duration: RateLimitWindow5h, + want: false, + }, + { + name: "expired window (started 6h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-6 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "exactly at boundary (started 5h ago, 5h window)", + start: rateLimitTimePtr(now.Add(-5 * time.Hour)), + duration: RateLimitWindow5h, + want: true, + }, + { + name: "active 1d window (started 12h ago)", + start: rateLimitTimePtr(now.Add(-12 * time.Hour)), + duration: RateLimitWindow1d, + want: false, + }, + { + name: "expired 1d window (started 25h ago)", + start: rateLimitTimePtr(now.Add(-25 * time.Hour)), + duration: RateLimitWindow1d, + want: true, + }, + { + name: "active 7d window (started 3d ago)", + start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)), + duration: RateLimitWindow7d, + want: false, + }, + { + name: "expired 7d window (started 8d ago)", + start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + duration: RateLimitWindow7d, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsWindowExpired(tt.start, tt.duration) + if got != tt.want { + t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKey_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + key APIKey + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)), + }, + want5h: 5.0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "all windows expired", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return raw usage", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 5.0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "mixed: 5h expired, 1d active, 7d nil", + key: APIKey{ + Usage5h: 5.0, + Usage1d: 10.0, + Usage7d: 50.0, + Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)), + Window7dStart: nil, + }, + want5h: 0, + want1d: 10.0, + want7d: 50.0, + }, + { + name: "zero usage with active windows", + key: APIKey{ + Usage5h: 0, + Usage1d: 0, + Usage7d: 0, + Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } + if got := tt.key.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.key.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + data APIKeyRateLimitData + want5h float64 + want1d float64 + want7d float64 + }{ + { + name: "all windows active", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)), + }, + want5h: 3.0, + want1d: 8.0, + want7d: 40.0, + }, + { + name: "all windows expired", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)), + Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)), + Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)), + }, + want5h: 0, + want1d: 0, + want7d: 0, + }, + { + name: "nil window starts return raw usage", + data: APIKeyRateLimitData{ + Usage5h: 3.0, + Usage1d: 8.0, + Usage7d: 40.0, + Window5hStart: nil, + Window1dStart: nil, + Window7dStart: nil, + }, + want5h: 3.0, + want1d: 8.0, + want7d: 40.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.data.EffectiveUsage5h(); got != tt.want5h { + t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h) + } + if got := tt.data.EffectiveUsage1d(); got != tt.want1d { + t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d) + } + if got := tt.data.EffectiveUsage7d(); got != tt.want7d { + t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d) + } + }) + } +} + +func rateLimitTimePtr(t time.Time) *time.Time { + return &t +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index b32a1d67..17c5b486 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -86,6 +86,30 @@ type APIKeyRateLimitData struct { Window7dStart *time.Time } +// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 { + if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) { + return 0 + } + return d.Usage5h +} + +// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 { + if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) { + return 0 + } + return d.Usage1d +} + +// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired. +func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { + if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) { + return 0 + } + return d.Usage7d +} + // APIKeyCache defines cache operations for API key service type APIKeyCache interface { GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index e055c0f7..f2ad0a3d 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -565,15 +565,15 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP needsReset := false // Reset expired windows in-memory for check purposes - if w5h != nil && time.Since(*w5h) >= 5*time.Hour { + if IsWindowExpired(w5h, RateLimitWindow5h) { usage5h = 0 needsReset = true } - if w1d != nil && time.Since(*w1d) >= 24*time.Hour { + if IsWindowExpired(w1d, RateLimitWindow1d) { usage1d = 0 needsReset = true } - if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour { + if IsWindowExpired(w7d, RateLimitWindow7d) { usage7d = 0 needsReset = true } @@ -589,12 +589,16 @@ func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *AP if loader, ok := s.apiKeyRateLimitLoader.(interface { ResetRateLimitWindows(ctx context.Context, id int64) error }); ok { - _ = loader.ResetRateLimitWindows(resetCtx, keyID) + if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err) + } } } // Invalidate cache so next request loads fresh data if s.cache != nil { - _ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID) + if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err) + } } }() }