refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable - Modify channel testing and user notification methods to use `GetRootUser()` - Update user cache and notification service to use more consistent user base type - Add new channel test notification type
This commit is contained in:
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
|
|||||||
|
|
||||||
var RetryTimes = 0
|
var RetryTimes = 0
|
||||||
|
|
||||||
var RootUserEmail = ""
|
//var RootUserEmail = ""
|
||||||
|
|
||||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||||
|
|
||||||
|
|||||||
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
|
|||||||
var testAllChannelsRunning bool = false
|
var testAllChannelsRunning bool = false
|
||||||
|
|
||||||
func testAllChannels(notify bool) error {
|
func testAllChannels(notify bool) error {
|
||||||
if common.RootUserEmail == "" {
|
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
|
||||||
}
|
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
if testAllChannelsRunning {
|
if testAllChannelsRunning {
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
|
|||||||
testAllChannelsRunning = false
|
testAllChannelsRunning = false
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
if notify {
|
if notify {
|
||||||
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -870,9 +870,6 @@ func EmailBind(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.Role == common.RoleRootUser {
|
|
||||||
common.RootUserEmail = email
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ const ContentValueParam = "{{value}}"
|
|||||||
const (
|
const (
|
||||||
NotifyTypeQuotaExceed = "quota_exceed"
|
NotifyTypeQuotaExceed = "quota_exceed"
|
||||||
NotifyTypeChannelUpdate = "channel_update"
|
NotifyTypeChannelUpdate = "channel_update"
|
||||||
|
NotifyTypeChannelTest = "channel_test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewNotify(t string, title string, content string, values []interface{}) Notify {
|
func NewNotify(t string, title string, content string, values []interface{}) Notify {
|
||||||
|
|||||||
@@ -42,6 +42,19 @@ type User struct {
|
|||||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) ToBaseUser() UserBase {
|
||||||
|
cache := UserBase{
|
||||||
|
Id: user.Id,
|
||||||
|
Group: user.Group,
|
||||||
|
Quota: user.Quota,
|
||||||
|
Status: user.Status,
|
||||||
|
Username: user.Username,
|
||||||
|
Setting: user.Setting,
|
||||||
|
Email: user.Email,
|
||||||
|
}
|
||||||
|
return cache
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) GetAccessToken() string {
|
func (user *User) GetAccessToken() string {
|
||||||
if user.AccessToken == nil {
|
if user.AccessToken == nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -687,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRootUserEmail() (email string) {
|
//func GetRootUserEmail() (email string) {
|
||||||
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
||||||
return email
|
// return email
|
||||||
|
//}
|
||||||
|
|
||||||
|
func GetRootUser() (user *User) {
|
||||||
|
DB.Where("role = ?", common.RoleRootUser).First(&user)
|
||||||
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserCache struct remains the same as it represents the cached data structure
|
// UserBase struct remains the same as it represents the cached data structure
|
||||||
type UserCache struct {
|
type UserBase struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -21,14 +21,14 @@ type UserCache struct {
|
|||||||
Setting string `json:"setting"`
|
Setting string `json:"setting"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserCache) GetSetting() map[string]interface{} {
|
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||||
if user.Setting == "" {
|
if user.Setting == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return common.StrToMap(user.Setting)
|
return common.StrToMap(user.Setting)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserCache) SetSetting(setting map[string]interface{}) {
|
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||||
settingBytes, err := json.Marshal(setting)
|
settingBytes, err := json.Marshal(setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to marshal setting: " + err.Error())
|
common.SysError("failed to marshal setting: " + err.Error())
|
||||||
@@ -56,25 +56,15 @@ func updateUserCache(user User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cache := &UserCache{
|
|
||||||
Id: user.Id,
|
|
||||||
Group: user.Group,
|
|
||||||
Quota: user.Quota,
|
|
||||||
Status: user.Status,
|
|
||||||
Username: user.Username,
|
|
||||||
Setting: user.Setting,
|
|
||||||
Email: user.Email,
|
|
||||||
}
|
|
||||||
|
|
||||||
return common.RedisHSetObj(
|
return common.RedisHSetObj(
|
||||||
getUserCacheKey(user.Id),
|
getUserCacheKey(user.Id),
|
||||||
cache,
|
user.ToBaseUser(),
|
||||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserCache gets complete user cache from hash
|
// GetUserCache gets complete user cache from hash
|
||||||
func GetUserCache(userId int) (userCache *UserCache, err error) {
|
func GetUserCache(userId int) (userCache *UserBase, err error) {
|
||||||
var user *User
|
var user *User
|
||||||
var fromDB bool
|
var fromDB bool
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -102,7 +92,7 @@ func GetUserCache(userId int) (userCache *UserCache, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create cache object from user data
|
// Create cache object from user data
|
||||||
userCache = &UserCache{
|
userCache = &UserBase{
|
||||||
Id: user.Id,
|
Id: user.Id,
|
||||||
Group: user.Group,
|
Group: user.Group,
|
||||||
Quota: user.Quota,
|
Quota: user.Quota,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
relaymodel "one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
|
|||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
notifyRootUser(subject, content)
|
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableChannel(channelId int, channelName string) {
|
func EnableChannel(channelId int, channelName string) {
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
notifyRootUser(subject, content)
|
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
|
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
|
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
|
||||||
if !common.AutomaticEnableChannelEnabled {
|
if !common.AutomaticEnableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,17 +9,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func notifyRootUser(subject string, content string) {
|
func NotifyRootUser(t string, subject string, content string) {
|
||||||
if common.RootUserEmail == "" {
|
user := model.GetRootUser().ToBaseUser()
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
_ = NotifyUser(&user, dto.NewNotify(t, subject, content, nil))
|
||||||
}
|
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NotifyUser(user *model.UserCache, data dto.Notify) error {
|
func NotifyUser(user *model.UserBase, data dto.Notify) error {
|
||||||
userSetting := user.GetSetting()
|
userSetting := user.GetSetting()
|
||||||
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user