refactor: Enhance user context and quota management

- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
This commit is contained in:
1808837298@qq.com
2025-02-25 20:56:16 +08:00
parent ccf13d445f
commit 069f2672c1
16 changed files with 97 additions and 132 deletions

View File

@@ -2,4 +2,9 @@ package constant
const ( const (
ContextKeyRequestStartTime = "request_start_time" ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
) )

View File

@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else { } else {
if shouldReturnQuota { if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota) err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) common.LogError(ctx, "fail to increase user quota: "+err.Error())
} }

View File

@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else { } else {
quota := task.Quota quota := task.Quota
if quota != 0 { if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota) err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) common.LogError(ctx, "fail to increase user quota: "+err.Error())
} }

View File

@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
} }
//user, _ := model.GetUserById(topUp.UserId, false) //user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000 //user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit)) err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil { if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp) log.Printf("易支付回调更新用户失败: %v", topUp)
return return

View File

@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return
} }
userEnabled, err := model.IsUserEnabled(token.UserId, false) userCache, err := model.GetUserCache(token.UserId)
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return return
} }
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled { if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
userCache.WriteContext(c)
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_key", token.Key) c.Set("token_key", token.Key)

View File

@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return return
} }
} }
userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c) modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return return
} }
userGroup, _ := model.GetUserGroup(userId, false) userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group") tokenGroup := c.GetString("token_group")
if tokenGroup != "" { if tokenGroup != "" {
// check common.UserUsableGroups[userGroup] // check common.UserUsableGroups[userGroup]

View File

@@ -1,8 +1,8 @@
package model package model
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"os" "os"
"strings" "strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
} }
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) { isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
username, _ := GetUsernameById(userId, false) username := c.GetString("username")
otherStr := common.MapToJsonStr(other) otherStr := common.MapToJsonStr(other)
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(c, "failed to record log: "+err.Error())
} }
if common.DataExportEnabled { if common.DataExportEnabled {
gopool.Go(func() { gopool.Go(func() {

View File

@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
// IsUserEnabled checks user status from Redis first, falls back to DB if needed //// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) { //func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() { // defer func() {
// Update Redis cache asynchronously on successful DB read // // Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) { // if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { // gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil { // if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error()) // common.SysError("failed to update user status cache: " + err.Error())
} // }
}) // })
} // }
}() // }()
if !fromDB && common.RedisEnabled { // if !fromDB && common.RedisEnabled {
// Try Redis first // // Try Redis first
status, err := getUserStatusCache(id) // status, err := getUserStatusCache(id)
if err == nil { // if err == nil {
return status == common.UserStatusEnabled, nil // return status == common.UserStatusEnabled, nil
} // }
// Don't return error - fall through to DB // // Don't return error - fall through to DB
} // }
fromDB = true // fromDB = true
var user User // var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error // err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil { // if err != nil {
return false, err // return false, err
} // }
//
return user.Status == common.UserStatusEnabled, nil // return user.Status == common.UserStatusEnabled, nil
} //}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return common.StrToMap(setting), nil return common.StrToMap(setting), nil
} }
func IncreaseUserQuota(id int, quota int) (err error) { func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error()) common.SysError("failed to increase user quota: " + err.Error())
} }
}) })
if common.BatchUpdateEnabled { if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
} }
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil return nil
} }
if delta > 0 { if delta > 0 {
return IncreaseUserQuota(id, delta) return IncreaseUserQuota(id, delta, false)
} else { } else {
return DecreaseUserQuota(id, -delta) return DecreaseUserQuota(id, -delta)
} }

View File

@@ -3,6 +3,7 @@ package model
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"time" "time"
@@ -21,6 +22,15 @@ type UserBase struct {
Setting string `json:"setting"` Setting string `json:"setting"`
} }
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} { func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" { if user.Setting == "" {
return nil return nil

View File

@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
resp, err := doRequest(c, req, info.ToRelayInfo()) resp, err := doRequest(c, req, info.RelayInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("do request failed: %w", err) return nil, fmt.Errorf("do request failed: %w", err)
} }

View File

@@ -50,6 +50,9 @@ type RelayInfo struct {
AudioUsage bool AudioUsage bool
ReasoningEffort string ReasoningEffort string
ChannelSetting map[string]interface{} ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType) apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{ info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true, IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"), BaseUrl: c.GetString("base_url"),
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
} }
type TaskRelayInfo struct { type TaskRelayInfo struct {
ChannelType int *RelayInfo
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
Action string Action string
OriginTaskID string OriginTaskID string
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
} }
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{ info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayInfo: GenRelayInfo(c),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
return info return info
} }
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -2,7 +2,6 @@ package relay
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil { if err != nil {
return &mjResp.Response return &mjResp.Response
} }
defer func(ctx context.Context) { defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true) err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil { if err != nil {
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{}) other := make(map[string]interface{})
other["model_price"] = modelPrice other["model_price"] = modelPrice
other["group_ratio"] = groupRatio other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other) quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}(c.Request.Context()) }()
midjResponse := &mjResp.Response midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{ midjourneyTask := &model.Midjourney{
UserId: userId, UserId: userId,
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} }
midjResponse := &midjResponseWithStatus.Response midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) { defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true) err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil { if err != nil {
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{}) other := make(map[string]interface{})
other["model_price"] = modelPrice other["model_price"] = modelPrice
other["group_ratio"] = groupRatio other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other) quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}(c.Request.Context()) }()
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md // 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功 //1-提交成功

View File

@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
} }
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足 // 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited { if !relayInfo.TokenUnlimited {
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }

View File

@@ -2,7 +2,6 @@ package relay
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return return
} }
defer func(ctx context.Context) { defer func() {
// release quota // release quota
if relayInfo.ConsumeQuota && taskErr == nil { if relayInfo.ConsumeQuota && taskErr == nil {
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true) err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{}) other := make(map[string]interface{})
other["model_price"] = modelPrice other["model_price"] = modelPrice
other["group_ratio"] = groupRatio other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other) modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }
} }
}(c.Request.Context()) }()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil { if taskErr != nil {

View File

@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if quota > 0 { if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota) err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else { } else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota) err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
} }
if err != nil { if err != nil {
return err return err
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if sendEmail { if sendEmail {
if (quota + preConsumedQuota) != 0 { if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota) checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
} }
} }
return nil return nil
} }
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() { gopool.Go(func() {
userCache, err := model.GetUserCache(userId) userSetting := relayInfo.UserSetting
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
threshold := common.QuotaRemindThreshold threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64)) threshold = int(userCustomThreshold.(float64))
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false quotaTooLow := false
consumeQuota := quota + preConsumedQuota consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold { if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true quotaTooLow = true
} }
if quotaTooLow { if quotaTooLow {
prompt := "您的额度即将用尽" prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>" content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink})) err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error())) common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
} }
} }
}) })

View File

@@ -11,47 +11,45 @@ import (
func NotifyRootUser(t string, subject string, content string) { func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser() user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil)) _ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
} }
func NotifyUser(user *model.UserBase, data dto.Notify) error { func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
userSetting := user.GetSetting()
notifyType, ok := userSetting[constant.UserSettingNotifyType] notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok { if !ok {
notifyType = constant.NotifyTypeEmail notifyType = constant.NotifyTypeEmail
} }
// Check notification limit // Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type) canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err return err
} }
if !canSend { if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType) return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
} }
switch notifyType { switch notifyType {
case constant.NotifyTypeEmail: case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email // check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string) userEmail = settingEmail.(string)
} }
if userEmail == "" { if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id)) common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil return nil
} }
return sendEmailNotify(userEmail, data) return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook: case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok { if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil return nil
} }
webhookURLStr, ok := webhookURL.(string) webhookURLStr, ok := webhookURL.(string)
if !ok { if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id)) common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil return nil
} }