refactor: Enhance user context and quota management
- Add new context keys for user-related information - Modify user cache and authentication middleware to populate context - Refactor quota and notification services to use context-based user data - Remove redundant database queries by leveraging context information - Update various components to use new context-based user retrieval methods
This commit is contained in:
@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req, info.ToRelayInfo())
|
||||
resp, err := doRequest(c, req, info.RelayInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -50,6 +50,9 @@ type RelayInfo struct {
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting map[string]interface{}
|
||||
UserSetting map[string]interface{}
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||
IsFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
StartTime time.Time
|
||||
ApiType int
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiKey string
|
||||
BaseUrl string
|
||||
|
||||
*RelayInfo
|
||||
Action string
|
||||
OriginTaskID string
|
||||
|
||||
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
|
||||
}
|
||||
|
||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
StartTime: startTime,
|
||||
ApiType: apiType,
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
RelayInfo: GenRelayInfo(c),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
|
||||
return &RelayInfo{
|
||||
ChannelType: info.ChannelType,
|
||||
ChannelId: info.ChannelId,
|
||||
TokenId: info.TokenId,
|
||||
UserId: info.UserId,
|
||||
Group: info.Group,
|
||||
StartTime: info.StartTime,
|
||||
ApiType: info.ApiType,
|
||||
RelayMode: info.RelayMode,
|
||||
UpstreamModelName: info.UpstreamModelName,
|
||||
RequestURLPath: info.RequestURLPath,
|
||||
ApiKey: info.ApiKey,
|
||||
BaseUrl: info.BaseUrl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
if err != nil {
|
||||
return &mjResp.Response
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
defer func() {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
if err != nil {
|
||||
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}()
|
||||
midjResponse := &mjResp.Response
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
defer func() {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
if err != nil {
|
||||
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}()
|
||||
|
||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||
//1-提交成功
|
||||
|
||||
@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
return
|
||||
}
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
defer func() {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
|
||||
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
|
||||
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}()
|
||||
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if taskErr != nil {
|
||||
|
||||
Reference in New Issue
Block a user