From 3da134489719118c598ba7fda404e6f78236b992 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 14:54:21 +0800 Subject: [PATCH] feat: Add user notification settings with quota warning and multiple notification methods - Implement user notification settings with email and webhook options - Add new user settings for quota warning threshold and notification preferences - Create backend API and database support for user notification configuration - Enhance frontend personal settings with notification configuration UI - Support custom notification email and webhook URL - Add service layer for sending user notifications --- common/logger.go | 8 + common/model-ratio.go | 14 - constant/env.go | 2 +- constant/user_setting.go | 14 + controller/user.go | 113 ++++++ docker-compose.yml | 2 +- dto/notify.go | 24 ++ dto/openai_request.go | 1 + model/token.go | 85 +---- model/user.go | 66 +++- model/user_cache.go | 336 +++++++++--------- relay/relay-mj.go | 4 +- relay/relay-text.go | 12 +- relay/relay_task.go | 2 +- router/api-router.go | 1 + service/quota.go | 91 ++++- service/user_notify.go | 43 +++ .../{system-setting.go => system_setting.go} | 0 web/src/components/PersonalSetting.js | 176 +++++++-- web/src/helpers/render.js | 2 +- 20 files changed, 685 insertions(+), 311 deletions(-) create mode 100644 constant/user_setting.go create mode 100644 dto/notify.go rename setting/{system-setting.go => system_setting.go} (100%) diff --git a/common/logger.go b/common/logger.go index 93d557d8..e72a73af 100644 --- a/common/logger.go +++ b/common/logger.go @@ -100,6 +100,14 @@ func LogQuota(quota int) string { } } +func FormatQuota(quota int) string { + if DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) + } else { + return fmt.Sprintf("%d", quota) + } +} + // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { jsonStr, err := json.Marshal(obj) diff --git a/common/model-ratio.go b/common/model-ratio.go index bb94ad36..ffeda83d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -477,20 +477,6 @@ func GetAudioCompletionRatio(name string) float64 { return 2 } -//func GetAudioPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.06 -// } -// return 0.06 -//} -// -//func GetAudioCompletionPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.24 -// } -// return 0.24 -//} - func GetCompletionRatioMap() map[string]float64 { if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio diff --git a/constant/env.go b/constant/env.go index 4135e8c7..c0ff5d10 100644 --- a/constant/env.go +++ b/constant/env.go @@ -44,5 +44,5 @@ func InitEnv() { } } -// 是否生成初始令牌,默认关闭。 +// GenerateDefaultToken 是否生成初始令牌,默认关闭。 var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) diff --git a/constant/user_setting.go b/constant/user_setting.go new file mode 100644 index 00000000..a5b921b2 --- /dev/null +++ b/constant/user_setting.go @@ -0,0 +1,14 @@ +package constant + +var ( + UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型 + UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值 + UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 + UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 + UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 +) + +var ( + NotifyTypeEmail = "email" // Email 邮件 + NotifyTypeWebhook = "webhook" // Webhook +) diff --git a/controller/user.go b/controller/user.go index 7146f00e..f8ce0354 100644 --- a/controller/user.go +++ b/controller/user.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "one-api/common" "one-api/model" "one-api/setting" @@ -913,3 +914,115 @@ func TopUp(c *gin.Context) { }) return } + +type UpdateUserSettingRequest struct { + QuotaWarningType string `json:"notify_type"` + QuotaWarningThreshold int `json:"quota_warning_threshold"` + WebhookUrl string `json:"webhook_url,omitempty"` + WebhookSecret string `json:"webhook_secret,omitempty"` + NotificationEmail string `json:"notification_email,omitempty"` +} + +func UpdateUserSetting(c *gin.Context) { + var req UpdateUserSettingRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + + // 验证预警类型 + if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的预警类型", + }) + return + } + + // 验证预警阈值 + if req.QuotaWarningThreshold <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "预警阈值必须大于0", + }) + return + } + + // 如果是webhook类型,验证webhook地址 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + if req.WebhookUrl == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Webhook地址不能为空", + }) + return + } + // 验证URL格式 + if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的Webhook地址", + }) + return + } + } + + // 如果是邮件类型,验证邮箱地址 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + // 验证邮箱格式 + if !strings.Contains(req.NotificationEmail, "@") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的邮箱地址", + }) + return + } + } + + userId := c.GetInt("id") + user, err := model.GetUserById(userId, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + // 构建设置 + settings := map[string]interface{}{ + constant.UserSettingNotifyType: req.QuotaWarningType, + constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, + } + + // 如果是webhook类型,添加webhook相关设置 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + settings[constant.UserSettingWebhookUrl] = req.WebhookUrl + if req.WebhookSecret != "" { + settings[constant.UserSettingWebhookSecret] = req.WebhookSecret + } + } + + // 如果提供了通知邮箱,添加到设置中 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + settings[constant.UserSettingNotificationEmail] = req.NotificationEmail + } + + // 更新用户设置 + user.SetSetting(settings) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "更新设置失败: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "设置已更新", + }) +} diff --git a/docker-compose.yml b/docker-compose.yml index 640cf074..0f23cea2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,7 @@ services: - redis - mysql healthcheck: - test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] + test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"] interval: 30s timeout: 10s retries: 3 diff --git a/dto/notify.go b/dto/notify.go new file mode 100644 index 00000000..8594cd32 --- /dev/null +++ b/dto/notify.go @@ -0,0 +1,24 @@ +package dto + +type Notify struct { + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Values []interface{} `json:"values"` +} + +const ContentValueParam = "{{value}}" + +const ( + NotifyTypeQuotaExceed = "quota_exceed" + NotifyTypeChannelUpdate = "channel_update" +) + +func NewNotify(t string, title string, content string, values []interface{}) Notify { + return Notify{ + Type: t, + Title: title, + Content: content, + Values: values, + } +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 58a4ce73..a142b437 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -18,6 +18,7 @@ type GeneralOpenAIRequest struct { Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` Prompt any `json:"prompt,omitempty"` + Prefix any `json:"prefix,omitempty"` Suffix any `json:"suffix,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"` diff --git a/model/token.go b/model/token.go index 3abd22cf..8587ea62 100644 --- a/model/token.go +++ b/model/token.go @@ -3,13 +3,11 @@ package model import ( "errors" "fmt" + "one-api/common" + "strings" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" - "one-api/common" - relaycommon "one-api/relay/common" - "one-api/setting" - "strconv" - "strings" ) type Token struct { @@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) { ).Error return err } - -func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { - if quota < 0 { - return errors.New("quota 不能为负数!") - } - if relayInfo.IsPlayground { - return nil - } - //if relayInfo.TokenUnlimited { - // return nil - //} - token, err := GetTokenById(relayInfo.TokenId) - if err != nil { - return err - } - if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return errors.New("令牌额度不足") - } - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - if err != nil { - return err - } - return nil -} - -func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { - - if quota > 0 { - err = DecreaseUserQuota(relayInfo.UserId, quota) - } else { - err = IncreaseUserQuota(relayInfo.UserId, -quota) - } - if err != nil { - return err - } - - if !relayInfo.IsPlayground { - if quota > 0 { - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - } else { - err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) - } - if err != nil { - return err - } - } - - if sendEmail { - if (quota + preConsumedQuota) != 0 { - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold - noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 - if quotaTooLow || noMoreQuota { - go func() { - email, err := GetUserEmail(relayInfo.UserId) - if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) - } - prompt := "您的额度即将用尽" - if noMoreQuota { - prompt = "您的额度已用尽" - } - if email != "" { - topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) - err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) - if err != nil { - common.SysError("failed to send email" + err.Error()) - } - common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota)) - } - }() - } - } - } - - return nil -} diff --git a/model/user.go b/model/user.go index 95123c21..f21c2a24 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "encoding/json" "errors" "fmt" "one-api/common" @@ -38,6 +39,7 @@ type User struct { InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` DeletedAt gorm.DeletedAt `gorm:"index"` LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"` + Setting string `json:"setting" gorm:"type:text;column:setting"` } func (user *User) GetAccessToken() string { @@ -51,6 +53,22 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } +func (user *User) GetSetting() map[string]interface{} { + if user.Setting == "" { + return nil + } + return common.StrToMap(user.Setting) +} + +func (user *User) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + user.Setting = string(settingBytes) +} + // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil func CheckUserExistOrDeleted(username string, email string) (bool, error) { var user User @@ -315,8 +333,8 @@ func (user *User) Update(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Edit(updatePassword bool) error { @@ -344,8 +362,8 @@ func (user *User) Edit(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Delete() error { @@ -371,8 +389,8 @@ func (user *User) HardDelete() error { // ValidateAndFill check password & user status func (user *User) ValidateAndFill() (err error) { // When querying with struct, GORM will only query with non-zero fields, - // that means if your field’s value is 0, '', false or other zero values, - // it won’t be used to build query conditions + // that means if your field's value is 0, '', false or other zero values, + // it won't be used to build query conditions password := user.Password username := strings.TrimSpace(user.Username) if username == "" || password == "" { @@ -531,7 +549,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { return quota, nil } // Don't return error - fall through to DB - //common.SysError("failed to get user quota from cache: " + err.Error()) } fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error @@ -580,6 +597,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { return group, nil } +// GetUserSetting gets setting from Redis first, falls back to DB if needed +func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) { + var setting string + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) { + gopool.Go(func() { + if err := updateUserSettingCache(id, setting); err != nil { + common.SysError("failed to update user setting cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + setting, err := getUserSettingCache(id) + if err == nil { + return setting, nil + } + // Don't return error - fall through to DB + } + fromDB = true + err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error + if err != nil { + return map[string]interface{}{}, err + } + + return common.StrToMap(setting), nil +} + func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") @@ -725,10 +771,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { return !errors.Is(err, gorm.ErrRecordNotFound) } -func (u *User) FillUserByLinuxDOId() error { - if u.LinuxDOId == "" { +func (user *User) FillUserByLinuxDOId() error { + if user.LinuxDOId == "" { return errors.New("linux do id is empty") } - err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error + err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error return err } diff --git a/model/user_cache.go b/model/user_cache.go index 9dc7e899..18fd3d2f 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -1,206 +1,210 @@ package model import ( + "encoding/json" "fmt" "one-api/common" "one-api/constant" - "strconv" "time" + + "github.com/bytedance/gopkg/util/gopool" ) -// Change UserCache struct to userCache -type userCache struct { +// UserCache struct remains the same as it represents the cached data structure +type UserCache struct { Id int `json:"id"` Group string `json:"group"` + Email string `json:"email"` Quota int `json:"quota"` Status int `json:"status"` - Role int `json:"role"` Username string `json:"username"` + Setting string `json:"setting"` } -// Rename all exported functions to private ones -// invalidateUserCache clears all user related cache +func (user *UserCache) GetSetting() map[string]interface{} { + if user.Setting == "" { + return nil + } + return common.StrToMap(user.Setting) +} + +func (user *UserCache) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + user.Setting = string(settingBytes) +} + +// getUserCacheKey returns the key for user cache +func getUserCacheKey(userId int) string { + return fmt.Sprintf("user:%d", userId) +} + +// invalidateUserCache clears user cache func invalidateUserCache(userId int) error { if !common.RedisEnabled { return nil } + return common.RedisHDelObj(getUserCacheKey(userId)) +} - keys := []string{ - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), +// updateUserCache updates all user cache fields using hash +func updateUserCache(user User) error { + if !common.RedisEnabled { + return nil } - for _, key := range keys { - if err := common.RedisDel(key); err != nil { - return fmt.Errorf("failed to delete cache key %s: %w", key, err) + cache := &UserCache{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + + return common.RedisHSetObj( + getUserCacheKey(user.Id), + cache, + time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + ) +} + +// GetUserCache gets complete user cache from hash +func GetUserCache(userId int) (userCache *UserCache, err error) { + var user *User + var fromDB bool + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) && user != nil { + gopool.Go(func() { + if err := updateUserCache(*user); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) } - } - return nil -} + }() -// updateUserGroupCache updates user group cache -func updateUserGroupCache(userId int, group string) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - group, - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserQuotaCache updates user quota cache -func updateUserQuotaCache(userId int, quota int) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf("%d", quota), - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserStatusCache updates user status cache -func updateUserStatusCache(userId int, userEnabled bool) error { - if !common.RedisEnabled { - return nil - } - enabled := "0" - if userEnabled { - enabled = "1" - } - return common.RedisSet( - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - enabled, - time.Duration(constant.UserId2StatusCacheSeconds)*time.Second, - ) -} - -// updateUserNameCache updates username cache -func updateUserNameCache(userId int, username string) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), - username, - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) -} - -// updateUserCache updates all user cache fields -func updateUserCache(userId int, username string, userGroup string, quota int, status int) error { - if !common.RedisEnabled { - return nil + // Try getting from Redis first + err = common.RedisHGetObj(getUserCacheKey(userId), &userCache) + if err == nil { + return userCache, nil } - if err := updateUserGroupCache(userId, userGroup); err != nil { - return fmt.Errorf("update group cache: %w", err) - } - - if err := updateUserQuotaCache(userId, quota); err != nil { - return fmt.Errorf("update quota cache: %w", err) - } - - if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil { - return fmt.Errorf("update status cache: %w", err) - } - - if err := updateUserNameCache(userId, username); err != nil { - return fmt.Errorf("update username cache: %w", err) - } - - return nil -} - -// getUserGroupCache gets user group from cache -func getUserGroupCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil - } - return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId)) -} - -// getUserQuotaCache gets user quota from cache -func getUserQuotaCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId)) + // If Redis fails, get from DB + fromDB = true + user, err = GetUserById(userId, false) if err != nil { - return 0, err + return nil, err // Return nil and error if DB lookup fails } - return strconv.Atoi(quotaStr) + + // Create cache object from user data + userCache = &UserCache{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + + return userCache, nil } -// getUserStatusCache gets user status from cache -func getUserStatusCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId)) - if err != nil { - return 0, err - } - return strconv.Atoi(statusStr) -} - -// getUserNameCache gets username from cache -func getUserNameCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil - } - return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId)) -} - -// getUserCache gets complete user cache -func getUserCache(userId int) (*userCache, error) { - if !common.RedisEnabled { - return nil, nil - } - - group, err := getUserGroupCache(userId) - if err != nil { - return nil, fmt.Errorf("get group cache: %w", err) - } - - quota, err := getUserQuotaCache(userId) - if err != nil { - return nil, fmt.Errorf("get quota cache: %w", err) - } - - status, err := getUserStatusCache(userId) - if err != nil { - return nil, fmt.Errorf("get status cache: %w", err) - } - - username, err := getUserNameCache(userId) - if err != nil { - return nil, fmt.Errorf("get username cache: %w", err) - } - - return &userCache{ - Id: userId, - Group: group, - Quota: quota, - Status: status, - Username: username, - }, nil -} - -// Add atomic quota operations +// Add atomic quota operations using hash fields func cacheIncrUserQuota(userId int, delta int64) error { if !common.RedisEnabled { return nil } - key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId) - return common.RedisIncr(key, delta) + return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta) } func cacheDecrUserQuota(userId int, delta int64) error { return cacheIncrUserQuota(userId, -delta) } + +// Helper functions to get individual fields if needed +func getUserGroupCache(userId int) (string, error) { + cache, err := GetUserCache(userId) + if err != nil { + return "", err + } + return cache.Group, nil +} + +func getUserQuotaCache(userId int) (int, error) { + cache, err := GetUserCache(userId) + if err != nil { + return 0, err + } + return cache.Quota, nil +} + +func getUserStatusCache(userId int) (int, error) { + cache, err := GetUserCache(userId) + if err != nil { + return 0, err + } + return cache.Status, nil +} + +func getUserNameCache(userId int) (string, error) { + cache, err := GetUserCache(userId) + if err != nil { + return "", err + } + return cache.Username, nil +} + +func getUserSettingCache(userId int) (map[string]interface{}, error) { + setting := make(map[string]interface{}) + cache, err := GetUserCache(userId) + if err != nil { + return setting, err + } + return cache.GetSetting(), nil +} + +// New functions for individual field updates +func updateUserStatusCache(userId int, status bool) error { + if !common.RedisEnabled { + return nil + } + statusInt := common.UserStatusEnabled + if !status { + statusInt = common.UserStatusDisabled + } + return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt)) +} + +func updateUserQuotaCache(userId int, quota int) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota)) +} + +func updateUserGroupCache(userId int, group string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Group", group) +} + +func updateUserNameCache(userId int, username string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Username", username) +} + +func updateUserSettingCache(userId int, setting string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting) +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 0facecab..766064cb 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func(ctx context.Context) { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func(ctx context.Context) { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/relay/relay-text.go b/relay/relay-text.go index f303ff6a..f9d1bd03 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -272,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } if userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) + return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %d", common.FormatQuota(userQuota), preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 @@ -282,18 +282,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > 100*preConsumedQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) } } if preConsumedQuota > 0 { - err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } @@ -310,7 +310,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us go func() { relayInfoCopy := *relayInfo - err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) + err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) if err != nil { common.SysError("error return pre-consumed quota: " + err.Error()) } @@ -368,7 +368,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN //} quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 61577faf..f03fcb2d 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { // release quota if relayInfo.ConsumeQuota && taskErr == nil { - err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/router/api-router.go b/router/api-router.go index b00595af..bf88449a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.POST("/pay", controller.RequestEpay) selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/aff_transfer", controller.TransferAffQuota) + selfRoute.PUT("/setting", controller.UpdateUserSetting) } adminRoute := userRoute.Group("/") diff --git a/service/quota.go b/service/quota.go index ab048008..13ce9763 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,8 +3,10 @@ package service import ( "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "math" "one-api/common" + constant2 "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) } - err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false) + err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } @@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } @@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } + +func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if relayInfo.IsPlayground { + return nil + } + //if relayInfo.TokenUnlimited { + // return nil + //} + token, err := model.GetTokenById(relayInfo.TokenId) + if err != nil { + return err + } + if !relayInfo.TokenUnlimited && token.RemainQuota < quota { + return errors.New("令牌额度不足") + } + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + if err != nil { + return err + } + return nil +} + +func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) { + + if quota > 0 { + err = model.DecreaseUserQuota(relayInfo.UserId, quota) + } else { + err = model.IncreaseUserQuota(relayInfo.UserId, -quota) + } + if err != nil { + return err + } + + if !relayInfo.IsPlayground { + if quota > 0 { + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + } else { + err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) + } + if err != nil { + return err + } + } + + if sendEmail { + if (quota + preConsumedQuota) != 0 { + checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota) + } + } + + return nil +} + +func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { + gopool.Go(func() { + userCache, err := model.GetUserCache(userId) + if err != nil { + common.SysError("failed to get user cache: " + err.Error()) + } + userSetting := userCache.GetSetting() + threshold := common.QuotaRemindThreshold + if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { + threshold = int(userCustomThreshold.(float64)) + } + + //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 + quotaTooLow := false + consumeQuota := quota + preConsumedQuota + if userCache.Quota-consumeQuota < threshold { + quotaTooLow = true + } + if quotaTooLow { + prompt := "您的额度即将用尽" + topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) + content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" + err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink})) + if err != nil { + common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error())) + } + } + }) +} diff --git a/service/user_notify.go b/service/user_notify.go index 7ae9062b..dd6f5606 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -3,7 +3,10 @@ package service import ( "fmt" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + "strings" ) func notifyRootUser(subject string, content string) { @@ -15,3 +18,43 @@ func notifyRootUser(subject string, content string) { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } + +func NotifyUser(user *model.UserCache, data dto.Notify) error { + userSetting := user.GetSetting() + notifyType, ok := userSetting[constant.UserSettingNotifyType] + if !ok { + notifyType = constant.NotifyTypeEmail + } + switch notifyType { + case constant.NotifyTypeEmail: + userEmail := user.Email + // check setting email + if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { + userEmail = settingEmail.(string) + } + if userEmail == "" { + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id)) + return nil + } + return sendEmailNotify(userEmail, data) + case constant.NotifyTypeWebhook: + webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] + if !ok { + common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) + return nil + } + // TODO: 实现webhook通知 + _ = webhookURL // 临时处理未使用警告,等待webhook实现 + } + return nil // 添加缺失的return +} + +func sendEmailNotify(userEmail string, data dto.Notify) error { + // make email content + content := data.Content + // 处理占位符 + for _, value := range data.Values { + content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1) + } + return common.SendEmail(data.Title, userEmail, content) +} diff --git a/setting/system-setting.go b/setting/system_setting.go similarity index 100% rename from setting/system-setting.go rename to setting/system_setting.go diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 2f112c37..66e2bb26 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -26,6 +26,10 @@ import { Tag, Typography, Collapsible, + Select, + Radio, + RadioGroup, + AutoComplete, } from '@douyinfe/semi-ui'; import { getQuotaPerUnit, @@ -67,14 +71,15 @@ const PersonalSetting = () => { const [transferAmount, setTransferAmount] = useState(0); const [isModelsExpanded, setIsModelsExpanded] = useState(false); const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量 + const [notificationSettings, setNotificationSettings] = useState({ + warningType: 'email', + warningThreshold: 100000, + webhookUrl: '', + webhookSecret: '', + notificationEmail: '' + }); useEffect(() => { - // let user = localStorage.getItem('user'); - // if (user) { - // userDispatch({ type: 'login', payload: user }); - // } - // console.log(localStorage.getItem('user')) - let status = localStorage.getItem('status'); if (status) { status = JSON.parse(status); @@ -105,6 +110,19 @@ const PersonalSetting = () => { return () => clearInterval(countdownInterval); // Clean up on unmount }, [disableButton, countdown]); + useEffect(() => { + if (userState?.user?.setting) { + const settings = JSON.parse(userState.user.setting); + setNotificationSettings({ + warningType: settings.notify_type || 'email', + warningThreshold: settings.quota_warning_threshold || 500000, + webhookUrl: settings.webhook_url || '', + webhookSecret: settings.webhook_secret || '', + notificationEmail: settings.notification_email || '' + }); + } + }, [userState?.user?.setting]); + const handleInputChange = (name, value) => { setInputs((inputs) => ({...inputs, [name]: value})); }; @@ -300,7 +318,36 @@ const PersonalSetting = () => { } }; + const handleNotificationSettingChange = (type, value) => { + setNotificationSettings(prev => ({ + ...prev, + [type]: value.target ? value.target.value : value // 处理 Radio 事件对象 + })); + }; + + const saveNotificationSettings = async () => { + try { + const res = await API.put('/api/user/setting', { + notify_type: notificationSettings.warningType, + quota_warning_threshold: notificationSettings.warningThreshold, + webhook_url: notificationSettings.webhookUrl, + webhook_secret: notificationSettings.webhookSecret, + notification_email: notificationSettings.notificationEmail + }); + + if (res.data.success) { + showSuccess(t('通知设置已更新')); + await getUserData(); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('更新通知设置失败')); + } + }; + return ( +
@@ -526,9 +573,7 @@ const PersonalSetting = () => {
{t('微信')} -
+
{
@@ -672,18 +721,8 @@ const PersonalSetting = () => { style={{marginTop: '10px'}} /> )} - {status.wechat_login && ( - - )} setShowWeChatBindModal(false)} - // onOpen={() => setShowWeChatBindModal(true)} visible={showWeChatBindModal} size={'small'} > @@ -707,9 +746,96 @@ const PersonalSetting = () => {
+ + {t('通知设置')} +
+ {t('通知方式')} +
+ handleNotificationSettingChange('warningType', value)} + > + {t('邮件通知')} + {t('Webhook通知')} + +
+
+ {notificationSettings.warningType === 'webhook' && ( + <> +
+ {t('Webhook地址')} +
+ handleNotificationSettingChange('webhookUrl', val)} + placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} + /> + + {t('系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} + +
+
+
+ {t('接口凭证(可选)')} +
+ handleNotificationSettingChange('webhookSecret', val)} + placeholder={t('请输入密钥')} + /> + + {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')} + + + {t('Authorization: Bearer your-secret-key')} + +
+
+ + )} + {notificationSettings.warningType === 'email' && ( +
+ {t('通知邮箱')} +
+ handleNotificationSettingChange('notificationEmail', val)} + placeholder={t('留空则使用账号绑定的邮箱')} + /> + + {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')} + +
+
+ )} +
+ {t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)} +
+ handleNotificationSettingChange('warningThreshold', val)} + style={{width: 200}} + placeholder={t('请输入预警额度')} + data={[ + { value: 100000, label: '0.2$' }, + { value: 500000, label: '1$' }, + { value: 1000000, label: '5$' }, + { value: 5000000, label: '10$' } + ]} + /> +
+ + {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')} + +
+
+ +
+
setShowEmailBindModal(false)} - // onOpen={() => setShowEmailBindModal(true)} onOk={bindEmail} visible={showEmailBindModal} size={'small'} diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 1b164037..310fc3ea 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -386,7 +386,7 @@ export function renderQuotaWithPrompt(quota, digits) { let displayInCurrency = localStorage.getItem('display_in_currency'); displayInCurrency = displayInCurrency === 'true'; if (displayInCurrency) { - return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; + return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; } return ''; }