From 3da134489719118c598ba7fda404e6f78236b992 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 14:54:21 +0800 Subject: [PATCH 1/8] feat: Add user notification settings with quota warning and multiple notification methods - Implement user notification settings with email and webhook options - Add new user settings for quota warning threshold and notification preferences - Create backend API and database support for user notification configuration - Enhance frontend personal settings with notification configuration UI - Support custom notification email and webhook URL - Add service layer for sending user notifications --- common/logger.go | 8 + common/model-ratio.go | 14 - constant/env.go | 2 +- constant/user_setting.go | 14 + controller/user.go | 113 ++++++ docker-compose.yml | 2 +- dto/notify.go | 24 ++ dto/openai_request.go | 1 + model/token.go | 85 +---- model/user.go | 66 +++- model/user_cache.go | 336 +++++++++--------- relay/relay-mj.go | 4 +- relay/relay-text.go | 12 +- relay/relay_task.go | 2 +- router/api-router.go | 1 + service/quota.go | 91 ++++- service/user_notify.go | 43 +++ .../{system-setting.go => system_setting.go} | 0 web/src/components/PersonalSetting.js | 176 +++++++-- web/src/helpers/render.js | 2 +- 20 files changed, 685 insertions(+), 311 deletions(-) create mode 100644 constant/user_setting.go create mode 100644 dto/notify.go rename setting/{system-setting.go => system_setting.go} (100%) diff --git a/common/logger.go b/common/logger.go index 93d557d8..e72a73af 100644 --- a/common/logger.go +++ b/common/logger.go @@ -100,6 +100,14 @@ func LogQuota(quota int) string { } } +func FormatQuota(quota int) string { + if DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) + } else { + return fmt.Sprintf("%d", quota) + } +} + // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { jsonStr, err := json.Marshal(obj) diff --git a/common/model-ratio.go b/common/model-ratio.go index bb94ad36..ffeda83d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -477,20 +477,6 @@ func GetAudioCompletionRatio(name string) float64 { return 2 } -//func GetAudioPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.06 -// } -// return 0.06 -//} -// -//func GetAudioCompletionPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.24 -// } -// return 0.24 -//} - func GetCompletionRatioMap() map[string]float64 { if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio diff --git a/constant/env.go b/constant/env.go index 4135e8c7..c0ff5d10 100644 --- a/constant/env.go +++ b/constant/env.go @@ -44,5 +44,5 @@ func InitEnv() { } } -// 是否生成初始令牌,默认关闭。 +// GenerateDefaultToken 是否生成初始令牌,默认关闭。 var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) diff --git a/constant/user_setting.go b/constant/user_setting.go new file mode 100644 index 00000000..a5b921b2 --- /dev/null +++ b/constant/user_setting.go @@ -0,0 +1,14 @@ +package constant + +var ( + UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型 + UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值 + UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 + UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 + UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 +) + +var ( + NotifyTypeEmail = "email" // Email 邮件 + NotifyTypeWebhook = "webhook" // Webhook +) diff --git a/controller/user.go b/controller/user.go index 7146f00e..f8ce0354 100644 --- a/controller/user.go +++ b/controller/user.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "one-api/common" "one-api/model" "one-api/setting" @@ -913,3 +914,115 @@ func TopUp(c *gin.Context) { }) return } + +type UpdateUserSettingRequest struct { + QuotaWarningType string `json:"notify_type"` + QuotaWarningThreshold int `json:"quota_warning_threshold"` + WebhookUrl string `json:"webhook_url,omitempty"` + WebhookSecret string `json:"webhook_secret,omitempty"` + NotificationEmail string `json:"notification_email,omitempty"` +} + +func UpdateUserSetting(c *gin.Context) { + var req UpdateUserSettingRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + + // 验证预警类型 + if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的预警类型", + }) + return + } + + // 验证预警阈值 + if req.QuotaWarningThreshold <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "预警阈值必须大于0", + }) + return + } + + // 如果是webhook类型,验证webhook地址 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + if req.WebhookUrl == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Webhook地址不能为空", + }) + return + } + // 验证URL格式 + if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的Webhook地址", + }) + return + } + } + + // 如果是邮件类型,验证邮箱地址 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + // 验证邮箱格式 + if !strings.Contains(req.NotificationEmail, "@") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的邮箱地址", + }) + return + } + } + + userId := c.GetInt("id") + user, err := model.GetUserById(userId, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + // 构建设置 + settings := map[string]interface{}{ + constant.UserSettingNotifyType: req.QuotaWarningType, + constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, + } + + // 如果是webhook类型,添加webhook相关设置 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + settings[constant.UserSettingWebhookUrl] = req.WebhookUrl + if req.WebhookSecret != "" { + settings[constant.UserSettingWebhookSecret] = req.WebhookSecret + } + } + + // 如果提供了通知邮箱,添加到设置中 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + settings[constant.UserSettingNotificationEmail] = req.NotificationEmail + } + + // 更新用户设置 + user.SetSetting(settings) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "更新设置失败: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "设置已更新", + }) +} diff --git a/docker-compose.yml b/docker-compose.yml index 640cf074..0f23cea2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,7 @@ services: - redis - mysql healthcheck: - test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] + test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"] interval: 30s timeout: 10s retries: 3 diff --git a/dto/notify.go b/dto/notify.go new file mode 100644 index 00000000..8594cd32 --- /dev/null +++ b/dto/notify.go @@ -0,0 +1,24 @@ +package dto + +type Notify struct { + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Values []interface{} `json:"values"` +} + +const ContentValueParam = "{{value}}" + +const ( + NotifyTypeQuotaExceed = "quota_exceed" + NotifyTypeChannelUpdate = "channel_update" +) + +func NewNotify(t string, title string, content string, values []interface{}) Notify { + return Notify{ + Type: t, + Title: title, + Content: content, + Values: values, + } +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 58a4ce73..a142b437 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -18,6 +18,7 @@ type GeneralOpenAIRequest struct { Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` Prompt any `json:"prompt,omitempty"` + Prefix any `json:"prefix,omitempty"` Suffix any `json:"suffix,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"` diff --git a/model/token.go b/model/token.go index 3abd22cf..8587ea62 100644 --- a/model/token.go +++ b/model/token.go @@ -3,13 +3,11 @@ package model import ( "errors" "fmt" + "one-api/common" + "strings" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" - "one-api/common" - relaycommon "one-api/relay/common" - "one-api/setting" - "strconv" - "strings" ) type Token struct { @@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) { ).Error return err } - -func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { - if quota < 0 { - return errors.New("quota 不能为负数!") - } - if relayInfo.IsPlayground { - return nil - } - //if relayInfo.TokenUnlimited { - // return nil - //} - token, err := GetTokenById(relayInfo.TokenId) - if err != nil { - return err - } - if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return errors.New("令牌额度不足") - } - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - if err != nil { - return err - } - return nil -} - -func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { - - if quota > 0 { - err = DecreaseUserQuota(relayInfo.UserId, quota) - } else { - err = IncreaseUserQuota(relayInfo.UserId, -quota) - } - if err != nil { - return err - } - - if !relayInfo.IsPlayground { - if quota > 0 { - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - } else { - err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) - } - if err != nil { - return err - } - } - - if sendEmail { - if (quota + preConsumedQuota) != 0 { - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold - noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 - if quotaTooLow || noMoreQuota { - go func() { - email, err := GetUserEmail(relayInfo.UserId) - if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) - } - prompt := "您的额度即将用尽" - if noMoreQuota { - prompt = "您的额度已用尽" - } - if email != "" { - topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) - err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) - if err != nil { - common.SysError("failed to send email" + err.Error()) - } - common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota)) - } - }() - } - } - } - - return nil -} diff --git a/model/user.go b/model/user.go index 95123c21..f21c2a24 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "encoding/json" "errors" "fmt" "one-api/common" @@ -38,6 +39,7 @@ type User struct { InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` DeletedAt gorm.DeletedAt `gorm:"index"` LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"` + Setting string `json:"setting" gorm:"type:text;column:setting"` } func (user *User) GetAccessToken() string { @@ -51,6 +53,22 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } +func (user *User) GetSetting() map[string]interface{} { + if user.Setting == "" { + return nil + } + return common.StrToMap(user.Setting) +} + +func (user *User) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + user.Setting = string(settingBytes) +} + // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil func CheckUserExistOrDeleted(username string, email string) (bool, error) { var user User @@ -315,8 +333,8 @@ func (user *User) Update(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Edit(updatePassword bool) error { @@ -344,8 +362,8 @@ func (user *User) Edit(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Delete() error { @@ -371,8 +389,8 @@ func (user *User) HardDelete() error { // ValidateAndFill check password & user status func (user *User) ValidateAndFill() (err error) { // When querying with struct, GORM will only query with non-zero fields, - // that means if your field’s value is 0, '', false or other zero values, - // it won’t be used to build query conditions + // that means if your field's value is 0, '', false or other zero values, + // it won't be used to build query conditions password := user.Password username := strings.TrimSpace(user.Username) if username == "" || password == "" { @@ -531,7 +549,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { return quota, nil } // Don't return error - fall through to DB - //common.SysError("failed to get user quota from cache: " + err.Error()) } fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error @@ -580,6 +597,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { return group, nil } +// GetUserSetting gets setting from Redis first, falls back to DB if needed +func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) { + var setting string + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) { + gopool.Go(func() { + if err := updateUserSettingCache(id, setting); err != nil { + common.SysError("failed to update user setting cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + setting, err := getUserSettingCache(id) + if err == nil { + return setting, nil + } + // Don't return error - fall through to DB + } + fromDB = true + err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error + if err != nil { + return map[string]interface{}{}, err + } + + return common.StrToMap(setting), nil +} + func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") @@ -725,10 +771,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { return !errors.Is(err, gorm.ErrRecordNotFound) } -func (u *User) FillUserByLinuxDOId() error { - if u.LinuxDOId == "" { +func (user *User) FillUserByLinuxDOId() error { + if user.LinuxDOId == "" { return errors.New("linux do id is empty") } - err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error + err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error return err } diff --git a/model/user_cache.go b/model/user_cache.go index 9dc7e899..18fd3d2f 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -1,206 +1,210 @@ package model import ( + "encoding/json" "fmt" "one-api/common" "one-api/constant" - "strconv" "time" + + "github.com/bytedance/gopkg/util/gopool" ) -// Change UserCache struct to userCache -type userCache struct { +// UserCache struct remains the same as it represents the cached data structure +type UserCache struct { Id int `json:"id"` Group string `json:"group"` + Email string `json:"email"` Quota int `json:"quota"` Status int `json:"status"` - Role int `json:"role"` Username string `json:"username"` + Setting string `json:"setting"` } -// Rename all exported functions to private ones -// invalidateUserCache clears all user related cache +func (user *UserCache) GetSetting() map[string]interface{} { + if user.Setting == "" { + return nil + } + return common.StrToMap(user.Setting) +} + +func (user *UserCache) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + user.Setting = string(settingBytes) +} + +// getUserCacheKey returns the key for user cache +func getUserCacheKey(userId int) string { + return fmt.Sprintf("user:%d", userId) +} + +// invalidateUserCache clears user cache func invalidateUserCache(userId int) error { if !common.RedisEnabled { return nil } + return common.RedisHDelObj(getUserCacheKey(userId)) +} - keys := []string{ - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), +// updateUserCache updates all user cache fields using hash +func updateUserCache(user User) error { + if !common.RedisEnabled { + return nil } - for _, key := range keys { - if err := common.RedisDel(key); err != nil { - return fmt.Errorf("failed to delete cache key %s: %w", key, err) + cache := &UserCache{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + + return common.RedisHSetObj( + getUserCacheKey(user.Id), + cache, + time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + ) +} + +// GetUserCache gets complete user cache from hash +func GetUserCache(userId int) (userCache *UserCache, err error) { + var user *User + var fromDB bool + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) && user != nil { + gopool.Go(func() { + if err := updateUserCache(*user); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) } - } - return nil -} + }() -// updateUserGroupCache updates user group cache -func updateUserGroupCache(userId int, group string) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - group, - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserQuotaCache updates user quota cache -func updateUserQuotaCache(userId int, quota int) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf("%d", quota), - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserStatusCache updates user status cache -func updateUserStatusCache(userId int, userEnabled bool) error { - if !common.RedisEnabled { - return nil - } - enabled := "0" - if userEnabled { - enabled = "1" - } - return common.RedisSet( - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - enabled, - time.Duration(constant.UserId2StatusCacheSeconds)*time.Second, - ) -} - -// updateUserNameCache updates username cache -func updateUserNameCache(userId int, username string) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), - username, - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserCache updates all user cache fields -func updateUserCache(userId int, username string, userGroup string, quota int, status int) error { - if !common.RedisEnabled { - return nil + // Try getting from Redis first + err = common.RedisHGetObj(getUserCacheKey(userId), &userCache) + if err == nil { + return userCache, nil } - if err := updateUserGroupCache(userId, userGroup); err != nil { - return fmt.Errorf("update group cache: %w", err) - } - - if err := updateUserQuotaCache(userId, quota); err != nil { - return fmt.Errorf("update quota cache: %w", err) - } - - if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil { - return fmt.Errorf("update status cache: %w", err) - } - - if err := updateUserNameCache(userId, username); err != nil { - return fmt.Errorf("update username cache: %w", err) - } - - return nil -} - -// getUserGroupCache gets user group from cache -func getUserGroupCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil - } - return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId)) -} - -// getUserQuotaCache gets user quota from cache -func getUserQuotaCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId)) + // If Redis fails, get from DB + fromDB = true + user, err = GetUserById(userId, false) if err != nil { - return 0, err + return nil, err // Return nil and error if DB lookup fails } - return strconv.Atoi(quotaStr) + + // Create cache object from user data + userCache = &UserCache{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + + return userCache, nil } -// getUserStatusCache gets user status from cache -func getUserStatusCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId)) - if err != nil { - return 0, err - } - return strconv.Atoi(statusStr) -} - -// getUserNameCache gets username from cache -func getUserNameCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil - } - return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId)) -} - -// getUserCache gets complete user cache -func getUserCache(userId int) (*userCache, error) { - if !common.RedisEnabled { - return nil, nil - } - - group, err := getUserGroupCache(userId) - if err != nil { - return nil, fmt.Errorf("get group cache: %w", err) - } - - quota, err := getUserQuotaCache(userId) - if err != nil { - return nil, fmt.Errorf("get quota cache: %w", err) - } - - status, err := getUserStatusCache(userId) - if err != nil { - return nil, fmt.Errorf("get status cache: %w", err) - } - - username, err := getUserNameCache(userId) - if err != nil { - return nil, fmt.Errorf("get username cache: %w", err) - } - - return &userCache{ - Id: userId, - Group: group, - Quota: quota, - Status: status, - Username: username, - }, nil -} - -// Add atomic quota operations +// Add atomic quota operations using hash fields func cacheIncrUserQuota(userId int, delta int64) error { if !common.RedisEnabled { return nil } - key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId) - return common.RedisIncr(key, delta) + return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta) } func cacheDecrUserQuota(userId int, delta int64) error { return cacheIncrUserQuota(userId, -delta) } + +// Helper functions to get individual fields if needed +func getUserGroupCache(userId int) (string, error) { + cache, err := GetUserCache(userId) + if err != nil { + return "", err + } + return cache.Group, nil +} + +func getUserQuotaCache(userId int) (int, error) { + cache, err := GetUserCache(userId) + if err != nil { + return 0, err + } + return cache.Quota, nil +} + +func getUserStatusCache(userId int) (int, error) { + cache, err := GetUserCache(userId) + if err != nil { + return 0, err + } + return cache.Status, nil +} + +func getUserNameCache(userId int) (string, error) { + cache, err := GetUserCache(userId) + if err != nil { + return "", err + } + return cache.Username, nil +} + +func getUserSettingCache(userId int) (map[string]interface{}, error) { + setting := make(map[string]interface{}) + cache, err := GetUserCache(userId) + if err != nil { + return setting, err + } + return cache.GetSetting(), nil +} + +// New functions for individual field updates +func updateUserStatusCache(userId int, status bool) error { + if !common.RedisEnabled { + return nil + } + statusInt := common.UserStatusEnabled + if !status { + statusInt = common.UserStatusDisabled + } + return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt)) +} + +func updateUserQuotaCache(userId int, quota int) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota)) +} + +func updateUserGroupCache(userId int, group string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Group", group) +} + +func updateUserNameCache(userId int, username string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Username", username) +} + +func updateUserSettingCache(userId int, setting string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting) +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 0facecab..766064cb 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func(ctx context.Context) { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func(ctx context.Context) { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/relay/relay-text.go b/relay/relay-text.go index f303ff6a..f9d1bd03 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -272,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } if userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) + return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %d", common.FormatQuota(userQuota), preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 @@ -282,18 +282,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > 100*preConsumedQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) } } if preConsumedQuota > 0 { - err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } @@ -310,7 +310,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us go func() { relayInfoCopy := *relayInfo - err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) + err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) if err != nil { common.SysError("error return pre-consumed quota: " + err.Error()) } @@ -368,7 +368,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN //} quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 61577faf..f03fcb2d 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { // release quota if relayInfo.ConsumeQuota && taskErr == nil { - err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/router/api-router.go b/router/api-router.go index b00595af..bf88449a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.POST("/pay", controller.RequestEpay) selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/aff_transfer", controller.TransferAffQuota) + selfRoute.PUT("/setting", controller.UpdateUserSetting) } adminRoute := userRoute.Group("/") diff --git a/service/quota.go b/service/quota.go index ab048008..13ce9763 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,8 +3,10 @@ package service import ( "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "math" "one-api/common" + constant2 "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) } - err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false) + err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } @@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } @@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } + +func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if relayInfo.IsPlayground { + return nil + } + //if relayInfo.TokenUnlimited { + // return nil + //} + token, err := model.GetTokenById(relayInfo.TokenId) + if err != nil { + return err + } + if !relayInfo.TokenUnlimited && token.RemainQuota < quota { + return errors.New("令牌额度不足") + } + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + if err != nil { + return err + } + return nil +} + +func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) { + + if quota > 0 { + err = model.DecreaseUserQuota(relayInfo.UserId, quota) + } else { + err = model.IncreaseUserQuota(relayInfo.UserId, -quota) + } + if err != nil { + return err + } + + if !relayInfo.IsPlayground { + if quota > 0 { + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + } else { + err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) + } + if err != nil { + return err + } + } + + if sendEmail { + if (quota + preConsumedQuota) != 0 { + checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota) + } + } + + return nil +} + +func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { + gopool.Go(func() { + userCache, err := model.GetUserCache(userId) + if err != nil { + common.SysError("failed to get user cache: " + err.Error()) + } + userSetting := userCache.GetSetting() + threshold := common.QuotaRemindThreshold + if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { + threshold = int(userCustomThreshold.(float64)) + } + + //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 + quotaTooLow := false + consumeQuota := quota + preConsumedQuota + if userCache.Quota-consumeQuota < threshold { + quotaTooLow = true + } + if quotaTooLow { + prompt := "您的额度即将用尽" + topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) + content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" + err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink})) + if err != nil { + common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error())) + } + } + }) +} diff --git a/service/user_notify.go b/service/user_notify.go index 7ae9062b..dd6f5606 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -3,7 +3,10 @@ package service import ( "fmt" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + "strings" ) func notifyRootUser(subject string, content string) { @@ -15,3 +18,43 @@ func notifyRootUser(subject string, content string) { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } + +func NotifyUser(user *model.UserCache, data dto.Notify) error { + userSetting := user.GetSetting() + notifyType, ok := userSetting[constant.UserSettingNotifyType] + if !ok { + notifyType = constant.NotifyTypeEmail + } + switch notifyType { + case constant.NotifyTypeEmail: + userEmail := user.Email + // check setting email + if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { + userEmail = settingEmail.(string) + } + if userEmail == "" { + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id)) + return nil + } + return sendEmailNotify(userEmail, data) + case constant.NotifyTypeWebhook: + webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] + if !ok { + common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) + return nil + } + // TODO: 实现webhook通知 + _ = webhookURL // 临时处理未使用警告,等待webhook实现 + } + return nil // 添加缺失的return +} + +func sendEmailNotify(userEmail string, data dto.Notify) error { + // make email content + content := data.Content + // 处理占位符 + for _, value := range data.Values { + content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1) + } + return common.SendEmail(data.Title, userEmail, content) +} diff --git a/setting/system-setting.go b/setting/system_setting.go similarity index 100% rename from setting/system-setting.go rename to setting/system_setting.go diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 2f112c37..66e2bb26 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -26,6 +26,10 @@ import { Tag, Typography, Collapsible, + Select, + Radio, + RadioGroup, + AutoComplete, } from '@douyinfe/semi-ui'; import { getQuotaPerUnit, @@ -67,14 +71,15 @@ const PersonalSetting = () => { const [transferAmount, setTransferAmount] = useState(0); const [isModelsExpanded, setIsModelsExpanded] = useState(false); const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量 + const [notificationSettings, setNotificationSettings] = useState({ + warningType: 'email', + warningThreshold: 100000, + webhookUrl: '', + webhookSecret: '', + notificationEmail: '' + }); useEffect(() => { - // let user = localStorage.getItem('user'); - // if (user) { - // userDispatch({ type: 'login', payload: user }); - // } - // console.log(localStorage.getItem('user')) - let status = localStorage.getItem('status'); if (status) { status = JSON.parse(status); @@ -105,6 +110,19 @@ const PersonalSetting = () => { return () => clearInterval(countdownInterval); // Clean up on unmount }, [disableButton, countdown]); + useEffect(() => { + if (userState?.user?.setting) { + const settings = JSON.parse(userState.user.setting); + setNotificationSettings({ + warningType: settings.notify_type || 'email', + warningThreshold: settings.quota_warning_threshold || 500000, + webhookUrl: settings.webhook_url || '', + webhookSecret: settings.webhook_secret || '', + notificationEmail: settings.notification_email || '' + }); + } + }, [userState?.user?.setting]); + const handleInputChange = (name, value) => { setInputs((inputs) => ({...inputs, [name]: value})); }; @@ -300,7 +318,36 @@ const PersonalSetting = () => { } }; + const handleNotificationSettingChange = (type, value) => { + setNotificationSettings(prev => ({ + ...prev, + [type]: value.target ? value.target.value : value // 处理 Radio 事件对象 + })); + }; + + const saveNotificationSettings = async () => { + try { + const res = await API.put('/api/user/setting', { + notify_type: notificationSettings.warningType, + quota_warning_threshold: notificationSettings.warningThreshold, + webhook_url: notificationSettings.webhookUrl, + webhook_secret: notificationSettings.webhookSecret, + notification_email: notificationSettings.notificationEmail + }); + + if (res.data.success) { + showSuccess(t('通知设置已更新')); + await getUserData(); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('更新通知设置失败')); + } + }; + return ( +
@@ -526,9 +573,7 @@ const PersonalSetting = () => {
{t('微信')} -
+
{
@@ -672,18 +721,8 @@ const PersonalSetting = () => { style={{marginTop: '10px'}} /> )} - {status.wechat_login && ( - - )} setShowWeChatBindModal(false)} - // onOpen={() => setShowWeChatBindModal(true)} visible={showWeChatBindModal} size={'small'} > @@ -707,9 +746,96 @@ const PersonalSetting = () => {
+ + {t('通知设置')} +
+ {t('通知方式')} +
+ handleNotificationSettingChange('warningType', value)} + > + {t('邮件通知')} + {t('Webhook通知')} + +
+
+ {notificationSettings.warningType === 'webhook' && ( + <> +
+ {t('Webhook地址')} +
+ handleNotificationSettingChange('webhookUrl', val)} + placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} + /> + + {t('系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} + +
+
+
+ {t('接口凭证(可选)')} +
+ handleNotificationSettingChange('webhookSecret', val)} + placeholder={t('请输入密钥')} + /> + + {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')} + + + {t('Authorization: Bearer your-secret-key')} + +
+
+ + )} + {notificationSettings.warningType === 'email' && ( +
+ {t('通知邮箱')} +
+ handleNotificationSettingChange('notificationEmail', val)} + placeholder={t('留空则使用账号绑定的邮箱')} + /> + + {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')} + +
+
+ )} +
+ {t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)} +
+ handleNotificationSettingChange('warningThreshold', val)} + style={{width: 200}} + placeholder={t('请输入预警额度')} + data={[ + { value: 100000, label: '0.2$' }, + { value: 500000, label: '1$' }, + { value: 1000000, label: '5$' }, + { value: 5000000, label: '10$' } + ]} + /> +
+ + {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')} + +
+
+ +
+
setShowEmailBindModal(false)} - // onOpen={() => setShowEmailBindModal(true)} onOk={bindEmail} visible={showEmailBindModal} size={'small'} diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 1b164037..310fc3ea 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -386,7 +386,7 @@ export function renderQuotaWithPrompt(quota, digits) { let displayInCurrency = localStorage.getItem('display_in_currency'); displayInCurrency = displayInCurrency === 'true'; if (displayInCurrency) { - return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; + return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; } return ''; } From 9d9c461c48876f0bb503c6446f8b31b4c4a2d9dc Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 15:01:43 +0800 Subject: [PATCH 2/8] refactor: Improve CompletionRatio handling with thread-safe access and initialization --- common/model-ratio.go | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index ffeda83d..4b64c79f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -233,7 +233,11 @@ var ( modelRatioMapMutex = sync.RWMutex{} ) -var CompletionRatio map[string]float64 = nil +var ( + CompletionRatio map[string]float64 = nil + CompletionRatioMutex = sync.RWMutex{} +) + var defaultCompletionRatio = map[string]float64{ "gpt-4-gizmo-*": 2, "gpt-4o-gizmo-*": 3, @@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } -func CompletionRatio2JSONString() string { +func GetCompletionRatioMap() map[string]float64 { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio } + return CompletionRatio +} + +func CompletionRatio2JSONString() string { + GetCompletionRatioMap() jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { SysError("error marshalling completion ratio: " + err.Error()) @@ -345,12 +356,9 @@ func CompletionRatio2JSONString() string { return string(jsonBytes) } -func UpdateCompletionRatioByJSONString(jsonStr string) error { - CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) -} - func GetCompletionRatio(name string) float64 { + GetCompletionRatioMap() + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -476,10 +484,3 @@ func GetAudioCompletionRatio(name string) float64 { } return 2 } - -func GetCompletionRatioMap() map[string]float64 { - if CompletionRatio == nil { - CompletionRatio = defaultCompletionRatio - } - return CompletionRatio -} From 56f6b2ab56089dbe18a419a18a483b90a0973eba Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 15:30:43 +0800 Subject: [PATCH 3/8] feat: Implement notification rate limiting mechanism - Add in-memory and Redis-based notification rate limiting - Create configurable hourly notification limits - Implement notification limit checking for user notifications - Add environment variables for customizing notification limits --- common/model-ratio.go | 7 +++ constant/env.go | 3 ++ service/notify-limit.go | 116 ++++++++++++++++++++++++++++++++++++++++ service/user_notify.go | 13 ++++- 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 service/notify-limit.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 4b64c79f..542cd93c 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -356,6 +356,13 @@ func CompletionRatio2JSONString() string { return string(jsonBytes) } +func UpdateCompletionRatioByJSONString(jsonStr string) error { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() + CompletionRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &CompletionRatio) +} + func GetCompletionRatio(name string) float64 { GetCompletionRatioMap() diff --git a/constant/env.go b/constant/env.go index c0ff5d10..2102bb7c 100644 --- a/constant/env.go +++ b/constant/env.go @@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{ var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) +var DefaultNotifyHourlyLimit = common.GetEnvOrDefault("NOTIFY_HOURLY_LIMIT", 2) +var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) + func InitEnv() { modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) if modelVersionMapStr == "" { diff --git a/service/notify-limit.go b/service/notify-limit.go new file mode 100644 index 00000000..d99f49cc --- /dev/null +++ b/service/notify-limit.go @@ -0,0 +1,116 @@ +package service + +import ( + "fmt" + "one-api/common" + "one-api/constant" + "strconv" + "sync" + "time" +) + +// notifyLimitStore is used for in-memory rate limiting when Redis is disabled +var ( + notifyLimitStore sync.Map + cleanupOnce sync.Once +) + +type limitCount struct { + Count int + Timestamp time.Time +} + +func getDuration() time.Duration { + minute := constant.NotificationLimitDurationMinute + return time.Duration(minute) * time.Minute +} + +// startCleanupTask starts a background task to clean up expired entries +func startCleanupTask() { + go func() { + for { + time.Sleep(time.Hour) + now := time.Now() + notifyLimitStore.Range(func(key, value interface{}) bool { + if limit, ok := value.(limitCount); ok { + if now.Sub(limit.Timestamp) >= getDuration() { + notifyLimitStore.Delete(key) + } + } + return true + }) + } + }() +} + +// CheckNotificationLimit checks if the user has exceeded their notification limit +// Returns true if the user can send notification, false if limit exceeded +func CheckNotificationLimit(userId int, notifyType string) (bool, error) { + if common.RedisEnabled { + return checkRedisLimit(userId, notifyType) + } + return checkMemoryLimit(userId, notifyType) +} + +func checkRedisLimit(userId int, notifyType string) (bool, error) { + key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + + // Get current count + count, err := common.RedisGet(key) + if err != nil && err.Error() != "redis: nil" { + return false, fmt.Errorf("failed to get notification count: %w", err) + } + + // If key doesn't exist, initialize it + if count == "" { + err = common.RedisSet(key, "1", getDuration()) + return true, err + } + + currentCount, _ := strconv.Atoi(count) + limit := constant.DefaultNotifyHourlyLimit + + // Check if limit is already reached + if currentCount >= limit { + return false, nil + } + + // Only increment if under limit + err = common.RedisIncr(key, 1) + if err != nil { + return false, fmt.Errorf("failed to increment notification count: %w", err) + } + + return true, nil +} + +func checkMemoryLimit(userId int, notifyType string) (bool, error) { + // Ensure cleanup task is started + cleanupOnce.Do(startCleanupTask) + + key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + now := time.Now() + + // Get current limit count or initialize new one + var currentLimit limitCount + if value, ok := notifyLimitStore.Load(key); ok { + currentLimit = value.(limitCount) + // Check if the entry has expired + if now.Sub(currentLimit.Timestamp) >= getDuration() { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + } else { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + + // Increment count + currentLimit.Count++ + + // Check against limits + limit := constant.DefaultNotifyHourlyLimit + + // Store updated count + notifyLimitStore.Store(key, currentLimit) + + return currentLimit.Count <= limit, nil +} diff --git a/service/user_notify.go b/service/user_notify.go index dd6f5606..829fa3e8 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -25,6 +25,17 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error { if !ok { notifyType = constant.NotifyTypeEmail } + + // Check notification limit + canSend, err := CheckNotificationLimit(user.Id, data.Type) + if err != nil { + common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + return err + } + if !canSend { + return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType) + } + switch notifyType { case constant.NotifyTypeEmail: userEmail := user.Email @@ -46,7 +57,7 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error { // TODO: 实现webhook通知 _ = webhookURL // 临时处理未使用警告,等待webhook实现 } - return nil // 添加缺失的return + return nil } func sendEmailNotify(userEmail string, data dto.Notify) error { From 0907a078b4e0f2861840a78d9f529d357017a2d5 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 15:59:17 +0800 Subject: [PATCH 4/8] refactor: Simplify root user notification and remove global email variable - Remove global `RootUserEmail` variable - Modify channel testing and user notification methods to use `GetRootUser()` - Update user cache and notification service to use more consistent user base type - Add new channel test notification type --- common/constants.go | 2 +- controller/channel-test.go | 9 ++------- controller/user.go | 3 --- dto/notify.go | 1 + model/user.go | 24 +++++++++++++++++++++--- model/user_cache.go | 24 +++++++----------------- service/channel.go | 10 +++++----- service/user_notify.go | 13 ++++--------- 8 files changed, 41 insertions(+), 45 deletions(-) diff --git a/common/constants.go b/common/constants.go index f967d066..04fb1b9a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -101,7 +101,7 @@ var PreConsumedQuota = 500 var RetryTimes = 0 -var RootUserEmail = "" +//var RootUserEmail = "" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" diff --git a/controller/channel-test.go b/controller/channel-test.go index 7e74bec2..4b0cc169 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } + testAllChannelsLock.Lock() if testAllChannelsRunning { testAllChannelsLock.Unlock() @@ -295,10 +293,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } + service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } }) return nil diff --git a/controller/user.go b/controller/user.go index f8ce0354..ac2cc839 100644 --- a/controller/user.go +++ b/controller/user.go @@ -870,9 +870,6 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { - common.RootUserEmail = email - } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/dto/notify.go b/dto/notify.go index 8594cd32..b75cec70 100644 --- a/dto/notify.go +++ b/dto/notify.go @@ -12,6 +12,7 @@ const ContentValueParam = "{{value}}" const ( NotifyTypeQuotaExceed = "quota_exceed" NotifyTypeChannelUpdate = "channel_update" + NotifyTypeChannelTest = "channel_test" ) func NewNotify(t string, title string, content string, values []interface{}) Notify { diff --git a/model/user.go b/model/user.go index f21c2a24..5aa0bdd3 100644 --- a/model/user.go +++ b/model/user.go @@ -42,6 +42,19 @@ type User struct { Setting string `json:"setting" gorm:"type:text;column:setting"` } +func (user *User) ToBaseUser() UserBase { + cache := UserBase{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + return cache +} + func (user *User) GetAccessToken() string { if user.AccessToken == nil { return "" @@ -687,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) { } } -func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) - return email +//func GetRootUserEmail() (email string) { +// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) +// return email +//} + +func GetRootUser() (user *User) { + DB.Where("role = ?", common.RoleRootUser).First(&user) + return user } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { diff --git a/model/user_cache.go b/model/user_cache.go index 18fd3d2f..38ae0397 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -10,8 +10,8 @@ import ( "github.com/bytedance/gopkg/util/gopool" ) -// UserCache struct remains the same as it represents the cached data structure -type UserCache struct { +// UserBase struct remains the same as it represents the cached data structure +type UserBase struct { Id int `json:"id"` Group string `json:"group"` Email string `json:"email"` @@ -21,14 +21,14 @@ type UserCache struct { Setting string `json:"setting"` } -func (user *UserCache) GetSetting() map[string]interface{} { +func (user *UserBase) GetSetting() map[string]interface{} { if user.Setting == "" { return nil } return common.StrToMap(user.Setting) } -func (user *UserCache) SetSetting(setting map[string]interface{}) { +func (user *UserBase) SetSetting(setting map[string]interface{}) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) @@ -56,25 +56,15 @@ func updateUserCache(user User) error { return nil } - cache := &UserCache{ - Id: user.Id, - Group: user.Group, - Quota: user.Quota, - Status: user.Status, - Username: user.Username, - Setting: user.Setting, - Email: user.Email, - } - return common.RedisHSetObj( getUserCacheKey(user.Id), - cache, + user.ToBaseUser(), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, ) } // GetUserCache gets complete user cache from hash -func GetUserCache(userId int) (userCache *UserCache, err error) { +func GetUserCache(userId int) (userCache *UserBase, err error) { var user *User var fromDB bool defer func() { @@ -102,7 +92,7 @@ func GetUserCache(userId int) (userCache *UserCache, err error) { } // Create cache object from user data - userCache = &UserCache{ + userCache = &UserBase{ Id: user.Id, Group: user.Group, Quota: user.Quota, diff --git a/service/channel.go b/service/channel.go index 73545b1e..76bcacf1 100644 --- a/service/channel.go +++ b/service/channel.go @@ -4,7 +4,7 @@ import ( "fmt" "net/http" "one-api/common" - relaymodel "one-api/dto" + "one-api/dto" "one-api/model" "one-api/setting" "strings" @@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } func EnableChannel(channelId int, channelName string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } -func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool { +func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool { if !common.AutomaticDisableChannelEnabled { return false } @@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } diff --git a/service/user_notify.go b/service/user_notify.go index 829fa3e8..d8b3c939 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -9,17 +9,12 @@ import ( "strings" ) -func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, common.RootUserEmail, content) - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } +func NotifyRootUser(t string, subject string, content string) { + user := model.GetRootUser().ToBaseUser() + _ = NotifyUser(&user, dto.NewNotify(t, subject, content, nil)) } -func NotifyUser(user *model.UserCache, data dto.Notify) error { +func NotifyUser(user *model.UserBase, data dto.Notify) error { userSetting := user.GetSetting() notifyType, ok := userSetting[constant.UserSettingNotifyType] if !ok { From b1847509a4e5120f55fb308851acb65f823ca8c3 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 19 Feb 2025 15:12:26 +0800 Subject: [PATCH 5/8] refactor: Optimize user caching and token retrieval methods --- controller/pricing.go | 2 +- controller/user.go | 2 +- model/token_cache.go | 2 +- model/user.go | 4 ++-- model/user_cache.go | 15 ++++++++++++++- service/quota.go | 2 +- service/user_notify.go | 2 +- 7 files changed, 21 insertions(+), 8 deletions(-) diff --git a/controller/pricing.go b/controller/pricing.go index 36caff9d..d7af5a4c 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) { } var group string if exists { - user, err := model.GetUserById(userId.(int), false) + user, err := model.GetUserCache(userId.(int)) if err == nil { group = user.Group } diff --git a/controller/user.go b/controller/user.go index ac2cc839..3002a613 100644 --- a/controller/user.go +++ b/controller/user.go @@ -472,7 +472,7 @@ func GetUserModels(c *gin.Context) { if err != nil { id = c.GetInt("id") } - user, err := model.GetUserById(id, true) + user, err := model.GetUserCache(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/model/token_cache.go b/model/token_cache.go index 99b762f5..0fe02fea 100644 --- a/model/token_cache.go +++ b/model/token_cache.go @@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error { func cacheGetTokenByKey(key string) (*Token, error) { hmacKey := common.GenerateHMAC(key) if !common.RedisEnabled { - return nil, nil + return nil, fmt.Errorf("redis is not enabled") } var token Token err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token) diff --git a/model/user.go b/model/user.go index 5aa0bdd3..427b0625 100644 --- a/model/user.go +++ b/model/user.go @@ -42,8 +42,8 @@ type User struct { Setting string `json:"setting" gorm:"type:text;column:setting"` } -func (user *User) ToBaseUser() UserBase { - cache := UserBase{ +func (user *User) ToBaseUser() *UserBase { + cache := &UserBase{ Id: user.Id, Group: user.Group, Quota: user.Quota, diff --git a/model/user_cache.go b/model/user_cache.go index 38ae0397..cc08288d 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -79,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { }() // Try getting from Redis first - err = common.RedisHGetObj(getUserCacheKey(userId), &userCache) + userCache, err = cacheGetUserBase(userId) if err == nil { return userCache, nil } @@ -105,6 +105,19 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { return userCache, nil } +func cacheGetUserBase(userId int) (*UserBase, error) { + if !common.RedisEnabled { + return nil, fmt.Errorf("redis is not enabled") + } + var userCache UserBase + // Try getting from Redis first + err := common.RedisHGetObj(getUserCacheKey(userId), &userCache) + if err != nil { + return nil, err + } + return &userCache, nil +} + // Add atomic quota operations using hash fields func cacheIncrUserQuota(userId int, delta int64) error { if !common.RedisEnabled { diff --git a/service/quota.go b/service/quota.go index 13ce9763..2ec04fe0 100644 --- a/service/quota.go +++ b/service/quota.go @@ -252,7 +252,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { //if relayInfo.TokenUnlimited { // return nil //} - token, err := model.GetTokenById(relayInfo.TokenId) + token, err := model.GetTokenByKey(relayInfo.TokenKey, false) if err != nil { return err } diff --git a/service/user_notify.go b/service/user_notify.go index d8b3c939..d51bbcec 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -11,7 +11,7 @@ import ( func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() - _ = NotifyUser(&user, dto.NewNotify(t, subject, content, nil)) + _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil)) } func NotifyUser(user *model.UserBase, data dto.Notify) error { From 4e871507cf810666c20dc025d9b78e7f80e4f085 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 19 Feb 2025 15:40:54 +0800 Subject: [PATCH 6/8] feat: Implement comprehensive webhook notification system --- relay/channel/cloudflare/adaptor.go | 3 +- service/cf_worker.go | 46 ++++++++-- service/user_notify.go | 15 +++- service/webhook.go | 118 ++++++++++++++++++++++++++ web/src/components/PersonalSetting.js | 28 +++++- 5 files changed, 197 insertions(+), 13 deletions(-) create mode 100644 service/webhook.go diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 75400098..5c2eadc2 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -4,13 +4,14 @@ import ( "bytes" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { diff --git a/service/cf_worker.go b/service/cf_worker.go index afe65411..40a1e294 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "encoding/json" "fmt" "net/http" "one-api/common" @@ -9,19 +10,46 @@ import ( "strings" ) +// WorkerRequest Worker请求的数据结构 +type WorkerRequest struct { + URL string `json:"url"` + Key string `json:"key"` + Method string `json:"method,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + Body json.RawMessage `json:"body,omitempty"` +} + +// DoWorkerRequest 通过Worker发送请求 +func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { + if !setting.EnableWorker() { + return nil, fmt.Errorf("worker not enabled") + } + if !strings.HasPrefix(req.URL, "https") { + return nil, fmt.Errorf("only support https url") + } + + workerUrl := setting.WorkerUrl + if !strings.HasSuffix(workerUrl, "/") { + workerUrl += "/" + } + + // 序列化worker请求数据 + workerPayload, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal worker payload: %v", err) + } + + return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) +} + func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) - if !strings.HasPrefix(originUrl, "https") { - return nil, fmt.Errorf("only support https url") + req := &WorkerRequest{ + URL: originUrl, + Key: setting.WorkerValidKey, } - workerUrl := setting.WorkerUrl - if !strings.HasSuffix(workerUrl, "/") { - workerUrl += "/" - } - // post request to worker - data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`) - return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data)) + return DoWorkerRequest(req) } else { common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) diff --git a/service/user_notify.go b/service/user_notify.go index d51bbcec..e01b7aa9 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -49,8 +49,19 @@ func NotifyUser(user *model.UserBase, data dto.Notify) error { common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) return nil } - // TODO: 实现webhook通知 - _ = webhookURL // 临时处理未使用警告,等待webhook实现 + webhookURLStr, ok := webhookURL.(string) + if !ok { + common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id)) + return nil + } + + // 获取 webhook secret + var webhookSecret string + if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok { + webhookSecret, _ = secret.(string) + } + + return SendWebhookNotify(webhookURLStr, webhookSecret, data) } return nil } diff --git a/service/webhook.go b/service/webhook.go new file mode 100644 index 00000000..ad2967eb --- /dev/null +++ b/service/webhook.go @@ -0,0 +1,118 @@ +package service + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "one-api/dto" + "one-api/setting" + "time" +) + +// WebhookPayload webhook 通知的负载数据 +type WebhookPayload struct { + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Values []interface{} `json:"values,omitempty"` + Timestamp int64 `json:"timestamp"` +} + +// generateSignature 生成 webhook 签名 +func generateSignature(secret string, payload []byte) string { + h := hmac.New(sha256.New, []byte(secret)) + h.Write(payload) + return hex.EncodeToString(h.Sum(nil)) +} + +// SendWebhookNotify 发送 webhook 通知 +func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error { + // 处理占位符 + content := data.Content + for _, value := range data.Values { + content = fmt.Sprintf(content, value) + } + + // 构建 webhook 负载 + payload := WebhookPayload{ + Type: data.Type, + Title: data.Title, + Content: content, + Values: data.Values, + Timestamp: time.Now().Unix(), + } + + // 序列化负载 + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal webhook payload: %v", err) + } + + // 创建 HTTP 请求 + var req *http.Request + var resp *http.Response + + if setting.EnableWorker() { + // 构建worker请求数据 + workerReq := &WorkerRequest{ + URL: webhookURL, + Key: setting.WorkerValidKey, + Method: http.MethodPost, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + Body: payloadBytes, + } + + // 如果有secret,添加签名到headers + if secret != "" { + signature := generateSignature(secret, payloadBytes) + workerReq.Headers["X-Webhook-Signature"] = signature + workerReq.Headers["Authorization"] = "Bearer " + secret + } + + resp, err = DoWorkerRequest(workerReq) + if err != nil { + return fmt.Errorf("failed to send webhook request through worker: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) + } + } else { + req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) + if err != nil { + return fmt.Errorf("failed to create webhook request: %v", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + + // 如果有 secret,生成签名 + if secret != "" { + signature := generateSignature(secret, payloadBytes) + req.Header.Set("X-Webhook-Signature", signature) + } + + // 发送请求 + client := GetImpatientHttpClient() + resp, err = client.Do(req) + if err != nil { + return fmt.Errorf("failed to send webhook request: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) + } + } + + return nil +} diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 66e2bb26..777cf042 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -78,6 +78,7 @@ const PersonalSetting = () => { webhookSecret: '', notificationEmail: '' }); + const [showWebhookDocs, setShowWebhookDocs] = useState(false); useEffect(() => { let status = localStorage.getItem('status'); @@ -771,7 +772,32 @@ const PersonalSetting = () => { placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} /> - {t('系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} + {t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} + + +
setShowWebhookDocs(!showWebhookDocs)}> + {t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'} +
+ +
+{`{
+    "type": "quota_exceed",      // 通知类型
+    "title": "标题",             // 通知标题
+    "content": "通知内容",       // 通知内容,支持 {{value}} 变量占位符
+    "values": ["值1", "值2"],    // 按顺序替换content中的 {{value}} 占位符
+    "timestamp": 1739950503      // 时间戳
+}
+
+示例:
+{
+    "type": "quota_exceed",
+    "title": "额度预警通知",
+    "content": "您的额度即将用尽,当前剩余额度为 {{value}}",
+    "values": ["$0.99"],
+    "timestamp": 1739950503
+}`}
+                                                    
+
From 585c19fc7033ad5af8f8bbe937c7ac634e57b6ae Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 19 Feb 2025 15:45:09 +0800 Subject: [PATCH 7/8] docs: Add proxy usage information note in SystemSetting component --- web/src/components/SystemSetting.js | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 1c953f6b..3149f91e 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -368,6 +368,17 @@ const SystemSetting = () => { ) + + 注意:代理功能仅对图片请求和 Webhook 请求生效,不会影响其他 API 请求。如需配置 API 请求代理,请参考 + + {' '}API 代理设置文档 + + 。 + Date: Wed, 19 Feb 2025 15:54:33 +0800 Subject: [PATCH 8/8] chore: update env name and README --- README.en.md | 2 ++ README.md | 3 +++ constant/env.go | 2 +- service/notify-limit.go | 4 ++-- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/README.en.md b/README.en.md index feb4b0bb..446c88f6 100644 --- a/README.en.md +++ b/README.en.md @@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20` - `CRYPTO_SECRET`: Encryption key for encrypting database content - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview` +- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10` +- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2` ## Deployment diff --git a/README.md b/README.md index cecefca6..e678832d 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,9 @@ - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。 - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。 - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview` +- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。 +- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。 + ## 部署 > [!TIP] diff --git a/constant/env.go b/constant/env.go index 2102bb7c..bffbfeea 100644 --- a/constant/env.go +++ b/constant/env.go @@ -29,7 +29,7 @@ var GeminiModelMap = map[string]string{ var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) -var DefaultNotifyHourlyLimit = common.GetEnvOrDefault("NOTIFY_HOURLY_LIMIT", 2) +var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) func InitEnv() { diff --git a/service/notify-limit.go b/service/notify-limit.go index d99f49cc..7bb62f62 100644 --- a/service/notify-limit.go +++ b/service/notify-limit.go @@ -68,7 +68,7 @@ func checkRedisLimit(userId int, notifyType string) (bool, error) { } currentCount, _ := strconv.Atoi(count) - limit := constant.DefaultNotifyHourlyLimit + limit := constant.NotifyLimitCount // Check if limit is already reached if currentCount >= limit { @@ -107,7 +107,7 @@ func checkMemoryLimit(userId int, notifyType string) (bool, error) { currentLimit.Count++ // Check against limits - limit := constant.DefaultNotifyHourlyLimit + limit := constant.NotifyLimitCount // Store updated count notifyLimitStore.Store(key, currentLimit)