From 56f6b2ab56089dbe18a419a18a483b90a0973eba Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 15:30:43 +0800 Subject: [PATCH] feat: Implement notification rate limiting mechanism - Add in-memory and Redis-based notification rate limiting - Create configurable hourly notification limits - Implement notification limit checking for user notifications - Add environment variables for customizing notification limits --- common/model-ratio.go | 7 +++ constant/env.go | 3 ++ service/notify-limit.go | 116 ++++++++++++++++++++++++++++++++++++++++ service/user_notify.go | 13 ++++- 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 service/notify-limit.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 4b64c79f..542cd93c 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -356,6 +356,13 @@ func CompletionRatio2JSONString() string { return string(jsonBytes) } +func UpdateCompletionRatioByJSONString(jsonStr string) error { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() + CompletionRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &CompletionRatio) +} + func GetCompletionRatio(name string) float64 { GetCompletionRatioMap() diff --git a/constant/env.go b/constant/env.go index c0ff5d10..2102bb7c 100644 --- a/constant/env.go +++ b/constant/env.go @@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{ var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) +var DefaultNotifyHourlyLimit = common.GetEnvOrDefault("NOTIFY_HOURLY_LIMIT", 2) +var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) + func InitEnv() { modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) if modelVersionMapStr == "" { diff --git a/service/notify-limit.go b/service/notify-limit.go new file mode 100644 index 00000000..d99f49cc --- /dev/null +++ b/service/notify-limit.go @@ -0,0 +1,116 @@ +package service + +import ( + "fmt" + "one-api/common" + "one-api/constant" + "strconv" + "sync" + "time" +) + +// notifyLimitStore is used for in-memory rate limiting when Redis is disabled +var ( + notifyLimitStore sync.Map + cleanupOnce sync.Once +) + +type limitCount struct { + Count int + Timestamp time.Time +} + +func getDuration() time.Duration { + minute := constant.NotificationLimitDurationMinute + return time.Duration(minute) * time.Minute +} + +// startCleanupTask starts a background task to clean up expired entries +func startCleanupTask() { + go func() { + for { + time.Sleep(time.Hour) + now := time.Now() + notifyLimitStore.Range(func(key, value interface{}) bool { + if limit, ok := value.(limitCount); ok { + if now.Sub(limit.Timestamp) >= getDuration() { + notifyLimitStore.Delete(key) + } + } + return true + }) + } + }() +} + +// CheckNotificationLimit checks if the user has exceeded their notification limit +// Returns true if the user can send notification, false if limit exceeded +func CheckNotificationLimit(userId int, notifyType string) (bool, error) { + if common.RedisEnabled { + return checkRedisLimit(userId, notifyType) + } + return checkMemoryLimit(userId, notifyType) +} + +func checkRedisLimit(userId int, notifyType string) (bool, error) { + key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + + // Get current count + count, err := common.RedisGet(key) + if err != nil && err.Error() != "redis: nil" { + return false, fmt.Errorf("failed to get notification count: %w", err) + } + + // If key doesn't exist, initialize it + if count == "" { + err = common.RedisSet(key, "1", getDuration()) + return true, err + } + + currentCount, _ := strconv.Atoi(count) + limit := constant.DefaultNotifyHourlyLimit + + // Check if limit is already reached + if currentCount >= limit { + return false, nil + } + + // Only increment if under limit + err = common.RedisIncr(key, 1) + if err != nil { + return false, fmt.Errorf("failed to increment notification count: %w", err) + } + + return true, nil +} + +func checkMemoryLimit(userId int, notifyType string) (bool, error) { + // Ensure cleanup task is started + cleanupOnce.Do(startCleanupTask) + + key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + now := time.Now() + + // Get current limit count or initialize new one + var currentLimit limitCount + if value, ok := notifyLimitStore.Load(key); ok { + currentLimit = value.(limitCount) + // Check if the entry has expired + if now.Sub(currentLimit.Timestamp) >= getDuration() { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + } else { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + + // Increment count + currentLimit.Count++ + + // Check against limits + limit := constant.DefaultNotifyHourlyLimit + + // Store updated count + notifyLimitStore.Store(key, currentLimit) + + return currentLimit.Count <= limit, nil +} diff --git a/service/user_notify.go b/service/user_notify.go index dd6f5606..829fa3e8 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -25,6 +25,17 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error { if !ok { notifyType = constant.NotifyTypeEmail } + + // Check notification limit + canSend, err := CheckNotificationLimit(user.Id, data.Type) + if err != nil { + common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + return err + } + if !canSend { + return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType) + } + switch notifyType { case constant.NotifyTypeEmail: userEmail := user.Email @@ -46,7 +57,7 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error { // TODO: 实现webhook通知 _ = webhookURL // 临时处理未使用警告,等待webhook实现 } - return nil // 添加缺失的return + return nil } func sendEmailNotify(userEmail string, data dto.Notify) error {