From 069f2672c1d0defde9e464e9e71b35a59f97372c Mon Sep 17 00:00:00 2001
From: "1808837298@qq.com" <1808837298@qq.com>
Date: Tue, 25 Feb 2025 20:56:16 +0800
Subject: [PATCH] refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
---
constant/context_key.go | 5 +++
controller/midjourney.go | 2 +-
controller/task.go | 2 +-
controller/topup.go | 2 +-
middleware/auth.go | 6 +++-
middleware/distributor.go | 3 +-
model/log.go | 10 +++---
model/user.go | 66 ++++++++++++++++++------------------
model/user_cache.go | 10 ++++++
relay/channel/api_request.go | 2 +-
relay/common/relay_info.go | 62 +++++----------------------------
relay/relay-mj.go | 13 ++++---
relay/relay-text.go | 3 +-
relay/relay_task.go | 9 +++--
service/quota.go | 18 ++++------
service/user_notify.go | 16 ++++-----
16 files changed, 97 insertions(+), 132 deletions(-)
diff --git a/constant/context_key.go b/constant/context_key.go
index b02f2d43..4b4d5cae 100644
--- a/constant/context_key.go
+++ b/constant/context_key.go
@@ -2,4 +2,9 @@ package constant
const (
ContextKeyRequestStartTime = "request_start_time"
+ ContextKeyUserSetting = "user_setting"
+ ContextKeyUserQuota = "user_quota"
+ ContextKeyUserStatus = "user_status"
+ ContextKeyUserEmail = "user_email"
+ ContextKeyUserGroup = "user_group"
)
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 2e351535..21027d8f 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
- err = model.IncreaseUserQuota(task.UserId, task.Quota)
+ err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
diff --git a/controller/task.go b/controller/task.go
index 928f7ed7..65f79ead 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else {
quota := task.Quota
if quota != 0 {
- err = model.IncreaseUserQuota(task.UserId, quota)
+ err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
diff --git a/controller/topup.go b/controller/topup.go
index fb51c545..a342ec3a 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
- err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
+ err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
diff --git a/middleware/auth.go b/middleware/auth.go
index 4d879a6c..a589f52c 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
- userEnabled, err := model.IsUserEnabled(token.UserId, false)
+ userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
+ userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
+
+ userCache.WriteContext(c)
+
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index e0f9342a..49fcf59b 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
- userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
- userGroup, _ := model.GetUserGroup(userId, false)
+ userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
diff --git a/model/log.go b/model/log.go
index 82278c60..ed7ec2c7 100644
--- a/model/log.go
+++ b/model/log.go
@@ -1,8 +1,8 @@
package model
import (
- "context"
"fmt"
+ "github.com/gin-gonic/gin"
"one-api/common"
"os"
"strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
}
}
-func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
+func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
- common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+ common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
- username, _ := GetUsernameById(userId, false)
+ username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.LogError(ctx, "failed to record log: "+err.Error())
+ common.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
diff --git a/model/user.go b/model/user.go
index 427b0625..524f56b6 100644
--- a/model/user.go
+++ b/model/user.go
@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
- _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
+ _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
-// IsUserEnabled checks user status from Redis first, falls back to DB if needed
-func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) {
- gopool.Go(func() {
- if err := updateUserStatusCache(id, status); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- // Try Redis first
- status, err := getUserStatusCache(id)
- if err == nil {
- return status == common.UserStatusEnabled, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- var user User
- err = DB.Where("id = ?", id).Select("status").Find(&user).Error
- if err != nil {
- return false, err
- }
-
- return user.Status == common.UserStatusEnabled, nil
-}
+//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
+//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
+// defer func() {
+// // Update Redis cache asynchronously on successful DB read
+// if shouldUpdateRedis(fromDB, err) {
+// gopool.Go(func() {
+// if err := updateUserStatusCache(id, status); err != nil {
+// common.SysError("failed to update user status cache: " + err.Error())
+// }
+// })
+// }
+// }()
+// if !fromDB && common.RedisEnabled {
+// // Try Redis first
+// status, err := getUserStatusCache(id)
+// if err == nil {
+// return status == common.UserStatusEnabled, nil
+// }
+// // Don't return error - fall through to DB
+// }
+// fromDB = true
+// var user User
+// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
+// if err != nil {
+// return false, err
+// }
+//
+// return user.Status == common.UserStatusEnabled, nil
+//}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return common.StrToMap(setting), nil
}
-func IncreaseUserQuota(id int, quota int) (err error) {
+func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error())
}
})
- if common.BatchUpdateEnabled {
+ if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil
}
if delta > 0 {
- return IncreaseUserQuota(id, delta)
+ return IncreaseUserQuota(id, delta, false)
} else {
return DecreaseUserQuota(id, -delta)
}
diff --git a/model/user_cache.go b/model/user_cache.go
index cc08288d..bc412e77 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
+ "github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"time"
@@ -21,6 +22,15 @@ type UserBase struct {
Setting string `json:"setting"`
}
+func (user *UserBase) WriteContext(c *gin.Context) {
+ c.Set(constant.ContextKeyUserGroup, user.Group)
+ c.Set(constant.ContextKeyUserQuota, user.Quota)
+ c.Set(constant.ContextKeyUserStatus, user.Status)
+ c.Set(constant.ContextKeyUserEmail, user.Email)
+ c.Set("username", user.Username)
+ c.Set(constant.ContextKeyUserSetting, user.GetSetting())
+}
+
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index cd1b5153..a60bc6f1 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
- resp, err := doRequest(c, req, info.ToRelayInfo())
+ resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index e1ecd83a..022ab628 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -50,6 +50,9 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
+ UserSetting map[string]interface{}
+ UserEmail string
+ UserQuota int
}
// 定义支持流式选项的通道类型
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
+ UserQuota: c.GetInt(constant.ContextKeyUserQuota),
+ UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
+ UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
}
type TaskRelayInfo struct {
- ChannelType int
- ChannelId int
- TokenId int
- UserId int
- Group string
- StartTime time.Time
- ApiType int
- RelayMode int
- UpstreamModelName string
- RequestURLPath string
- ApiKey string
- BaseUrl string
-
+ *RelayInfo
Action string
OriginTaskID string
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
- channelType := c.GetInt("channel_type")
- channelId := c.GetInt("channel_id")
-
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- group := c.GetString("group")
- startTime := time.Now()
-
- apiType, _ := relayconstant.ChannelType2APIType(channelType)
-
info := &TaskRelayInfo{
- RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
- BaseUrl: c.GetString("base_url"),
- RequestURLPath: c.Request.URL.String(),
- ChannelType: channelType,
- ChannelId: channelId,
- TokenId: tokenId,
- UserId: userId,
- Group: group,
- StartTime: startTime,
- ApiType: apiType,
- ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
- }
- if info.BaseUrl == "" {
- info.BaseUrl = common.ChannelBaseURLs[channelType]
+ RelayInfo: GenRelayInfo(c),
}
return info
}
-
-func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
- return &RelayInfo{
- ChannelType: info.ChannelType,
- ChannelId: info.ChannelId,
- TokenId: info.TokenId,
- UserId: info.UserId,
- Group: info.Group,
- StartTime: info.StartTime,
- ApiType: info.ApiType,
- RelayMode: info.RelayMode,
- UpstreamModelName: info.UpstreamModelName,
- RequestURLPath: info.RequestURLPath,
- ApiKey: info.ApiKey,
- BaseUrl: info.BaseUrl,
- }
-}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 766064cb..57de8d10 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
- "context"
"encoding/json"
"fmt"
"io"
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return &mjResp.Response
}
- defer func(ctx context.Context) {
+ defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
+ model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
- }(c.Request.Context())
+ }()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
midjResponse := &midjResponseWithStatus.Response
- defer func(ctx context.Context) {
+ defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
+ model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
- }(c.Request.Context())
+ }()
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
diff --git a/relay/relay-text.go b/relay/relay-text.go
index bfd91cdf..9dd72b4e 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
+ relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
}
if preConsumedQuota > 0 {
- err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index f03fcb2d..591ad3bb 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
- "context"
"encoding/json"
"errors"
"fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
- defer func(ctx context.Context) {
+ defer func() {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
- err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
- model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
+ model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
- }(c.Request.Context())
+ }()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {
diff --git a/service/quota.go b/service/quota.go
index 2cae93de..294bc552 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
- err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
+ err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
}
if err != nil {
return err
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if sendEmail {
if (quota + preConsumedQuota) != 0 {
- checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
+ checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
}
}
return nil
}
-func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
+func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() {
- userCache, err := model.GetUserCache(userId)
- if err != nil {
- common.SysError("failed to get user cache: " + err.Error())
- }
- userSetting := userCache.GetSetting()
+ userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
- if userCache.Quota-consumeQuota < threshold {
+ if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
- err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
+ err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
- common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
+ common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
}
})
diff --git a/service/user_notify.go b/service/user_notify.go
index e01b7aa9..db291f0f 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -11,47 +11,45 @@ import (
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
- _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
+ _ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
}
-func NotifyUser(user *model.UserBase, data dto.Notify) error {
- userSetting := user.GetSetting()
+func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
- canSend, err := CheckNotificationLimit(user.Id, data.Type)
+ canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
- return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+ return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
- userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
- common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
+ common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
- common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
+ common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
- common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
+ common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}