From 0907a078b4e0f2861840a78d9f529d357017a2d5 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 18 Feb 2025 15:59:17 +0800 Subject: [PATCH] refactor: Simplify root user notification and remove global email variable - Remove global `RootUserEmail` variable - Modify channel testing and user notification methods to use `GetRootUser()` - Update user cache and notification service to use more consistent user base type - Add new channel test notification type --- common/constants.go | 2 +- controller/channel-test.go | 9 ++------- controller/user.go | 3 --- dto/notify.go | 1 + model/user.go | 24 +++++++++++++++++++++--- model/user_cache.go | 24 +++++++----------------- service/channel.go | 10 +++++----- service/user_notify.go | 13 ++++--------- 8 files changed, 41 insertions(+), 45 deletions(-) diff --git a/common/constants.go b/common/constants.go index f967d066..04fb1b9a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -101,7 +101,7 @@ var PreConsumedQuota = 500 var RetryTimes = 0 -var RootUserEmail = "" +//var RootUserEmail = "" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" diff --git a/controller/channel-test.go b/controller/channel-test.go index 7e74bec2..4b0cc169 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } + testAllChannelsLock.Lock() if testAllChannelsRunning { testAllChannelsLock.Unlock() @@ -295,10 +293,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } + service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } }) return nil diff --git a/controller/user.go b/controller/user.go index f8ce0354..ac2cc839 100644 --- a/controller/user.go +++ b/controller/user.go @@ -870,9 +870,6 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { - common.RootUserEmail = email - } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/dto/notify.go b/dto/notify.go index 8594cd32..b75cec70 100644 --- a/dto/notify.go +++ b/dto/notify.go @@ -12,6 +12,7 @@ const ContentValueParam = "{{value}}" const ( NotifyTypeQuotaExceed = "quota_exceed" NotifyTypeChannelUpdate = "channel_update" + NotifyTypeChannelTest = "channel_test" ) func NewNotify(t string, title string, content string, values []interface{}) Notify { diff --git a/model/user.go b/model/user.go index f21c2a24..5aa0bdd3 100644 --- a/model/user.go +++ b/model/user.go @@ -42,6 +42,19 @@ type User struct { Setting string `json:"setting" gorm:"type:text;column:setting"` } +func (user *User) ToBaseUser() UserBase { + cache := UserBase{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + return cache +} + func (user *User) GetAccessToken() string { if user.AccessToken == nil { return "" @@ -687,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) { } } -func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) - return email +//func GetRootUserEmail() (email string) { +// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) +// return email +//} + +func GetRootUser() (user *User) { + DB.Where("role = ?", common.RoleRootUser).First(&user) + return user } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { diff --git a/model/user_cache.go b/model/user_cache.go index 18fd3d2f..38ae0397 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -10,8 +10,8 @@ import ( "github.com/bytedance/gopkg/util/gopool" ) -// UserCache struct remains the same as it represents the cached data structure -type UserCache struct { +// UserBase struct remains the same as it represents the cached data structure +type UserBase struct { Id int `json:"id"` Group string `json:"group"` Email string `json:"email"` @@ -21,14 +21,14 @@ type UserCache struct { Setting string `json:"setting"` } -func (user *UserCache) GetSetting() map[string]interface{} { +func (user *UserBase) GetSetting() map[string]interface{} { if user.Setting == "" { return nil } return common.StrToMap(user.Setting) } -func (user *UserCache) SetSetting(setting map[string]interface{}) { +func (user *UserBase) SetSetting(setting map[string]interface{}) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) @@ -56,25 +56,15 @@ func updateUserCache(user User) error { return nil } - cache := &UserCache{ - Id: user.Id, - Group: user.Group, - Quota: user.Quota, - Status: user.Status, - Username: user.Username, - Setting: user.Setting, - Email: user.Email, - } - return common.RedisHSetObj( getUserCacheKey(user.Id), - cache, + user.ToBaseUser(), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, ) } // GetUserCache gets complete user cache from hash -func GetUserCache(userId int) (userCache *UserCache, err error) { +func GetUserCache(userId int) (userCache *UserBase, err error) { var user *User var fromDB bool defer func() { @@ -102,7 +92,7 @@ func GetUserCache(userId int) (userCache *UserCache, err error) { } // Create cache object from user data - userCache = &UserCache{ + userCache = &UserBase{ Id: user.Id, Group: user.Group, Quota: user.Quota, diff --git a/service/channel.go b/service/channel.go index 73545b1e..76bcacf1 100644 --- a/service/channel.go +++ b/service/channel.go @@ -4,7 +4,7 @@ import ( "fmt" "net/http" "one-api/common" - relaymodel "one-api/dto" + "one-api/dto" "one-api/model" "one-api/setting" "strings" @@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } func EnableChannel(channelId int, channelName string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } -func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool { +func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool { if !common.AutomaticDisableChannelEnabled { return false } @@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } diff --git a/service/user_notify.go b/service/user_notify.go index 829fa3e8..d8b3c939 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -9,17 +9,12 @@ import ( "strings" ) -func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, common.RootUserEmail, content) - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } +func NotifyRootUser(t string, subject string, content string) { + user := model.GetRootUser().ToBaseUser() + _ = NotifyUser(&user, dto.NewNotify(t, subject, content, nil)) } -func NotifyUser(user *model.UserCache, data dto.Notify) error { +func NotifyUser(user *model.UserBase, data dto.Notify) error { userSetting := user.GetSetting() notifyType, ok := userSetting[constant.UserSettingNotifyType] if !ok {