refactor: Enhance user context and quota management
- Add new context keys for user-related information - Modify user cache and authentication middleware to populate context - Refactor quota and notification services to use context-based user data - Remove redundant database queries by leveraging context information - Update various components to use new context-based user retrieval methods
This commit is contained in:
@@ -2,4 +2,9 @@ package constant
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
ContextKeyRequestStartTime = "request_start_time"
|
ContextKeyRequestStartTime = "request_start_time"
|
||||||
|
ContextKeyUserSetting = "user_setting"
|
||||||
|
ContextKeyUserQuota = "user_quota"
|
||||||
|
ContextKeyUserStatus = "user_status"
|
||||||
|
ContextKeyUserEmail = "user_email"
|
||||||
|
ContextKeyUserGroup = "user_group"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
if shouldReturnQuota {
|
if shouldReturnQuota {
|
||||||
err = model.IncreaseUserQuota(task.UserId, task.Quota)
|
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
} else {
|
} else {
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||||
//user.Quota += topUp.Amount * 500000
|
//user.Quota += topUp.Amount * 500000
|
||||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userEnabled, err := model.IsUserEnabled(token.UserId, false)
|
userCache, err := model.GetUserCache(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
userEnabled := userCache.Status == common.UserStatusEnabled
|
||||||
if !userEnabled {
|
if !userEnabled {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
c.Set("token_id", token.Id)
|
c.Set("token_id", token.Id)
|
||||||
c.Set("token_key", token.Key)
|
c.Set("token_key", token.Key)
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
userId := c.GetInt("id")
|
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("specific_channel_id")
|
channelId, ok := c.Get("specific_channel_id")
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup, _ := model.GetUserGroup(userId, false)
|
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := c.GetString("token_group")
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
// check common.UserUsableGroups[userGroup]
|
// check common.UserUsableGroups[userGroup]
|
||||||
|
|||||||
10
model/log.go
10
model/log.go
@@ -1,8 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||||
isStream bool, group string, other map[string]interface{}) {
|
isStream bool, group string, other map[string]interface{}) {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
username, _ := GetUsernameById(userId, false)
|
username := c.GetString("username")
|
||||||
otherStr := common.MapToJsonStr(other)
|
otherStr := common.MapToJsonStr(other)
|
||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
common.LogError(c, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
if common.DataExportEnabled {
|
if common.DataExportEnabled {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
|
|||||||
@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
|
|||||||
}
|
}
|
||||||
if inviterId != 0 {
|
if inviterId != 0 {
|
||||||
if common.QuotaForInvitee > 0 {
|
if common.QuotaForInvitee > 0 {
|
||||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
|
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
||||||
}
|
}
|
||||||
if common.QuotaForInviter > 0 {
|
if common.QuotaForInviter > 0 {
|
||||||
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
|
|||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= common.RoleAdminUser
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
||||||
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||||
defer func() {
|
// defer func() {
|
||||||
// Update Redis cache asynchronously on successful DB read
|
// // Update Redis cache asynchronously on successful DB read
|
||||||
if shouldUpdateRedis(fromDB, err) {
|
// if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
// gopool.Go(func() {
|
||||||
if err := updateUserStatusCache(id, status); err != nil {
|
// if err := updateUserStatusCache(id, status); err != nil {
|
||||||
common.SysError("failed to update user status cache: " + err.Error())
|
// common.SysError("failed to update user status cache: " + err.Error())
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
}()
|
// }()
|
||||||
if !fromDB && common.RedisEnabled {
|
// if !fromDB && common.RedisEnabled {
|
||||||
// Try Redis first
|
// // Try Redis first
|
||||||
status, err := getUserStatusCache(id)
|
// status, err := getUserStatusCache(id)
|
||||||
if err == nil {
|
// if err == nil {
|
||||||
return status == common.UserStatusEnabled, nil
|
// return status == common.UserStatusEnabled, nil
|
||||||
}
|
// }
|
||||||
// Don't return error - fall through to DB
|
// // Don't return error - fall through to DB
|
||||||
}
|
// }
|
||||||
fromDB = true
|
// fromDB = true
|
||||||
var user User
|
// var user User
|
||||||
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return false, err
|
// return false, err
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
return user.Status == common.UserStatusEnabled, nil
|
// return user.Status == common.UserStatusEnabled, nil
|
||||||
}
|
//}
|
||||||
|
|
||||||
func ValidateAccessToken(token string) (user *User) {
|
func ValidateAccessToken(token string) (user *User) {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
|||||||
return common.StrToMap(setting), nil
|
return common.StrToMap(setting), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
common.SysError("failed to increase user quota: " + err.Error())
|
common.SysError("failed to increase user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if common.BatchUpdateEnabled {
|
if !db && common.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if delta > 0 {
|
if delta > 0 {
|
||||||
return IncreaseUserQuota(id, delta)
|
return IncreaseUserQuota(id, delta, false)
|
||||||
} else {
|
} else {
|
||||||
return DecreaseUserQuota(id, -delta)
|
return DecreaseUserQuota(id, -delta)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"time"
|
"time"
|
||||||
@@ -21,6 +22,15 @@ type UserBase struct {
|
|||||||
Setting string `json:"setting"`
|
Setting string `json:"setting"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||||
|
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||||
|
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||||
|
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||||
|
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||||
|
c.Set("username", user.Username)
|
||||||
|
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||||
|
}
|
||||||
|
|
||||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||||
if user.Setting == "" {
|
if user.Setting == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
}
|
}
|
||||||
resp, err := doRequest(c, req, info.ToRelayInfo())
|
resp, err := doRequest(c, req, info.RelayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("do request failed: %w", err)
|
return nil, fmt.Errorf("do request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ type RelayInfo struct {
|
|||||||
AudioUsage bool
|
AudioUsage bool
|
||||||
ReasoningEffort string
|
ReasoningEffort string
|
||||||
ChannelSetting map[string]interface{}
|
ChannelSetting map[string]interface{}
|
||||||
|
UserSetting map[string]interface{}
|
||||||
|
UserEmail string
|
||||||
|
UserQuota int
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义支持流式选项的通道类型
|
// 定义支持流式选项的通道类型
|
||||||
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||||
|
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
|
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||||
|
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||||
|
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||||
IsFirstResponse: true,
|
IsFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: c.GetString("base_url"),
|
BaseUrl: c.GetString("base_url"),
|
||||||
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TaskRelayInfo struct {
|
type TaskRelayInfo struct {
|
||||||
ChannelType int
|
*RelayInfo
|
||||||
ChannelId int
|
|
||||||
TokenId int
|
|
||||||
UserId int
|
|
||||||
Group string
|
|
||||||
StartTime time.Time
|
|
||||||
ApiType int
|
|
||||||
RelayMode int
|
|
||||||
UpstreamModelName string
|
|
||||||
RequestURLPath string
|
|
||||||
ApiKey string
|
|
||||||
BaseUrl string
|
|
||||||
|
|
||||||
Action string
|
Action string
|
||||||
OriginTaskID string
|
OriginTaskID string
|
||||||
|
|
||||||
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
|
||||||
|
|
||||||
info := &TaskRelayInfo{
|
info := &TaskRelayInfo{
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayInfo: GenRelayInfo(c),
|
||||||
BaseUrl: c.GetString("base_url"),
|
|
||||||
RequestURLPath: c.Request.URL.String(),
|
|
||||||
ChannelType: channelType,
|
|
||||||
ChannelId: channelId,
|
|
||||||
TokenId: tokenId,
|
|
||||||
UserId: userId,
|
|
||||||
Group: group,
|
|
||||||
StartTime: startTime,
|
|
||||||
ApiType: apiType,
|
|
||||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
|
||||||
}
|
|
||||||
if info.BaseUrl == "" {
|
|
||||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
|
|
||||||
return &RelayInfo{
|
|
||||||
ChannelType: info.ChannelType,
|
|
||||||
ChannelId: info.ChannelId,
|
|
||||||
TokenId: info.TokenId,
|
|
||||||
UserId: info.UserId,
|
|
||||||
Group: info.Group,
|
|
||||||
StartTime: info.StartTime,
|
|
||||||
ApiType: info.ApiType,
|
|
||||||
RelayMode: info.RelayMode,
|
|
||||||
UpstreamModelName: info.UpstreamModelName,
|
|
||||||
RequestURLPath: info.RequestURLPath,
|
|
||||||
ApiKey: info.ApiKey,
|
|
||||||
BaseUrl: info.BaseUrl,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return &mjResp.Response
|
return &mjResp.Response
|
||||||
}
|
}
|
||||||
defer func(ctx context.Context) {
|
defer func() {
|
||||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
other := make(map[string]interface{})
|
other := make(map[string]interface{})
|
||||||
other["model_price"] = modelPrice
|
other["model_price"] = modelPrice
|
||||||
other["group_ratio"] = groupRatio
|
other["group_ratio"] = groupRatio
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}()
|
||||||
midjResponse := &mjResp.Response
|
midjResponse := &mjResp.Response
|
||||||
midjourneyTask := &model.Midjourney{
|
midjourneyTask := &model.Midjourney{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
midjResponse := &midjResponseWithStatus.Response
|
midjResponse := &midjResponseWithStatus.Response
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func() {
|
||||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
other := make(map[string]interface{})
|
other := make(map[string]interface{})
|
||||||
other["model_price"] = modelPrice
|
other["model_price"] = modelPrice
|
||||||
other["group_ratio"] = groupRatio
|
other["group_ratio"] = groupRatio
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}()
|
||||||
|
|
||||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||||
//1-提交成功
|
//1-提交成功
|
||||||
|
|||||||
@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
relayInfo.UserQuota = userQuota
|
||||||
if userQuota > 100*preConsumedQuota {
|
if userQuota > 100*preConsumedQuota {
|
||||||
// 用户额度充足,判断令牌额度是否充足
|
// 用户额度充足,判断令牌额度是否充足
|
||||||
if !relayInfo.TokenUnlimited {
|
if !relayInfo.TokenUnlimited {
|
||||||
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func() {
|
||||||
// release quota
|
// release quota
|
||||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||||
|
|
||||||
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
other := make(map[string]interface{})
|
other := make(map[string]interface{})
|
||||||
other["model_price"] = modelPrice
|
other["model_price"] = modelPrice
|
||||||
other["group_ratio"] = groupRatio
|
other["group_ratio"] = groupRatio
|
||||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}()
|
||||||
|
|
||||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
|
|||||||
if quota > 0 {
|
if quota > 0 {
|
||||||
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
|
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
|
||||||
} else {
|
} else {
|
||||||
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
|
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
|
|||||||
|
|
||||||
if sendEmail {
|
if sendEmail {
|
||||||
if (quota + preConsumedQuota) != 0 {
|
if (quota + preConsumedQuota) != 0 {
|
||||||
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
|
checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
|
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
userCache, err := model.GetUserCache(userId)
|
userSetting := relayInfo.UserSetting
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to get user cache: " + err.Error())
|
|
||||||
}
|
|
||||||
userSetting := userCache.GetSetting()
|
|
||||||
threshold := common.QuotaRemindThreshold
|
threshold := common.QuotaRemindThreshold
|
||||||
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
||||||
threshold = int(userCustomThreshold.(float64))
|
threshold = int(userCustomThreshold.(float64))
|
||||||
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
|
|||||||
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
|
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
|
||||||
quotaTooLow := false
|
quotaTooLow := false
|
||||||
consumeQuota := quota + preConsumedQuota
|
consumeQuota := quota + preConsumedQuota
|
||||||
if userCache.Quota-consumeQuota < threshold {
|
if relayInfo.UserQuota-consumeQuota < threshold {
|
||||||
quotaTooLow = true
|
quotaTooLow = true
|
||||||
}
|
}
|
||||||
if quotaTooLow {
|
if quotaTooLow {
|
||||||
prompt := "您的额度即将用尽"
|
prompt := "您的额度即将用尽"
|
||||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||||
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||||
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
|
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
|
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,47 +11,45 @@ import (
|
|||||||
|
|
||||||
func NotifyRootUser(t string, subject string, content string) {
|
func NotifyRootUser(t string, subject string, content string) {
|
||||||
user := model.GetRootUser().ToBaseUser()
|
user := model.GetRootUser().ToBaseUser()
|
||||||
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
|
_ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func NotifyUser(user *model.UserBase, data dto.Notify) error {
|
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
|
||||||
userSetting := user.GetSetting()
|
|
||||||
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
||||||
if !ok {
|
if !ok {
|
||||||
notifyType = constant.NotifyTypeEmail
|
notifyType = constant.NotifyTypeEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check notification limit
|
// Check notification limit
|
||||||
canSend, err := CheckNotificationLimit(user.Id, data.Type)
|
canSend, err := CheckNotificationLimit(userId, data.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !canSend {
|
if !canSend {
|
||||||
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
|
return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch notifyType {
|
switch notifyType {
|
||||||
case constant.NotifyTypeEmail:
|
case constant.NotifyTypeEmail:
|
||||||
userEmail := user.Email
|
|
||||||
// check setting email
|
// check setting email
|
||||||
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
|
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
|
||||||
userEmail = settingEmail.(string)
|
userEmail = settingEmail.(string)
|
||||||
}
|
}
|
||||||
if userEmail == "" {
|
if userEmail == "" {
|
||||||
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
|
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return sendEmailNotify(userEmail, data)
|
return sendEmailNotify(userEmail, data)
|
||||||
case constant.NotifyTypeWebhook:
|
case constant.NotifyTypeWebhook:
|
||||||
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
|
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
|
||||||
if !ok {
|
if !ok {
|
||||||
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
|
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
webhookURLStr, ok := webhookURL.(string)
|
webhookURLStr, ok := webhookURL.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
|
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user