refactor: user cache logic

This commit is contained in:
CalciumIon
2024-12-29 16:50:26 +08:00
parent a1b864bc5e
commit ed435e5c8f
20 changed files with 548 additions and 225 deletions

View File

@@ -2,9 +2,11 @@ package common
import ( import (
"context" "context"
"github.com/go-redis/redis/v8" "fmt"
"os" "os"
"time" "time"
"github.com/go-redis/redis/v8"
) )
var RDB *redis.Client var RDB *redis.Client
@@ -104,3 +106,21 @@ func RedisDecrease(key string, value int64) error {
} }
return nil return nil
} }
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int) error {
ctx := context.Background()
// 检查键是否存在
exists, err := RDB.Exists(ctx, key).Result()
if err != nil {
return err
}
if exists == 0 {
return fmt.Errorf("key does not exist") // 键不存在,返回错误
}
// 键存在执行INCRBY操作
result := RDB.IncrBy(ctx, key, int64(delta))
return result.Err()
}

18
constant/cache_key.go Normal file
View File

@@ -0,0 +1,18 @@
package constant
import "one-api/common"
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
const (
// Cache keys
UserGroupKeyFmt = "user_group:%d"
UserQuotaKeyFmt = "user_quota:%d"
UserEnabledKeyFmt = "user_enabled:%d"
UserUsernameKeyFmt = "user_name:%d"
)

View File

@@ -21,7 +21,7 @@ func GetSubscription(c *gin.Context) {
usedQuota = token.UsedQuota usedQuota = token.UsedQuota
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId, false)
usedQuota, err = model.GetUserUsedQuota(userId) usedQuota, err = model.GetUserUsedQuota(userId)
} }
if expiredTime <= 0 { if expiredTime <= 0 {

View File

@@ -23,7 +23,7 @@ func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string) usableGroups := make(map[string]string)
userGroup := "" userGroup := ""
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId) userGroup, _ = model.GetUserGroup(userId, false)
for groupName, _ := range setting.GetGroupRatioCopy() { for groupName, _ := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use // UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup) userUsableGroups := setting.GetUserUsableGroups(userGroup)

View File

@@ -166,7 +166,7 @@ func ListModels(c *gin.Context) {
} }
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId) userGroup, err := model.GetUserGroup(userId, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -153,7 +153,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId) //err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil { if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error()) common.LogError(ctx, "error update user quota cache: "+err.Error())
} else { } else {

View File

@@ -75,7 +75,7 @@ func RequestEpay(c *gin.Context) {
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id) group, err := model.GetUserGroup(id, true)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return return
@@ -236,7 +236,7 @@ func RequestAmount(c *gin.Context) {
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id) group, err := model.GetUserGroup(id, true)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return return

View File

@@ -201,7 +201,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return
} }
userEnabled, err := model.CacheIsUserEnabled(token.UserId) userEnabled, err := model.IsUserEnabled(token.UserId, false)
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return return

View File

@@ -40,7 +40,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return return
} }
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.GetUserGroup(userId, false)
tokenGroup := c.GetString("token_group") tokenGroup := c.GetString("token_group")
if tokenGroup != "" { if tokenGroup != "" {
// check common.UserUsableGroups[userGroup] // check common.UserUsableGroups[userGroup]

View File

@@ -6,20 +6,13 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"one-api/constant"
"sort" "sort"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// 仅用于定时同步缓存 // 仅用于定时同步缓存
var token2UserId = make(map[string]int) var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex var token2UserIdLock sync.RWMutex
@@ -29,7 +22,7 @@ func cacheSetToken(token *Token) error {
if err != nil { if err != nil {
return err return err
} }
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error())) common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
return err return err
@@ -57,7 +50,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
return token, nil return token, nil
} }
// 如果缓存中存在,则续期时间 // 如果缓存中存在,则续期时间
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second) err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second)
err = json.Unmarshal([]byte(tokenObjectString), &token) err = json.Unmarshal([]byte(tokenObjectString), &token)
return token, err return token, err
} }
@@ -101,109 +94,105 @@ func SyncTokenCache(frequency int) {
} }
} }
func CacheGetUserGroup(id int) (group string, err error) { //func CacheGetUserGroup(id int) (group string, err error) {
if !common.RedisEnabled { // if !common.RedisEnabled {
return GetUserGroup(id) // return GetUserGroup(id)
} // }
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) // group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
if err != nil { // if err != nil {
group, err = GetUserGroup(id) // group, err = GetUserGroup(id)
if err != nil { // if err != nil {
return "", err // return "", err
} // }
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) // err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
if err != nil { // if err != nil {
common.SysError("Redis set user group error: " + err.Error()) // common.SysError("Redis set user group error: " + err.Error())
} // }
} // }
return group, err // return group, err
} //}
//
func CacheGetUsername(id int) (username string, err error) { //func CacheGetUsername(id int) (username string, err error) {
if !common.RedisEnabled { // if !common.RedisEnabled {
return GetUsernameById(id) // return GetUsernameById(id)
} // }
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id)) // username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
if err != nil { // if err != nil {
username, err = GetUsernameById(id) // username, err = GetUsernameById(id)
if err != nil { // if err != nil {
return "", err // return "", err
} // }
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second) // err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
if err != nil { // if err != nil {
common.SysError("Redis set user group error: " + err.Error()) // common.SysError("Redis set user group error: " + err.Error())
} // }
} // }
return username, err // return username, err
} //}
//
func CacheGetUserQuota(id int) (quota int, err error) { //func CacheGetUserQuota(id int) (quota int, err error) {
if !common.RedisEnabled { // if !common.RedisEnabled {
return GetUserQuota(id) // return GetUserQuota(id)
} // }
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) // quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil { // if err != nil {
quota, err = GetUserQuota(id) // quota, err = GetUserQuota(id)
if err != nil { // if err != nil {
return 0, err // return 0, err
} // }
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) // return quota, nil
if err != nil { // }
common.SysError("Redis set user quota error: " + err.Error()) // quota, err = strconv.Atoi(quotaString)
} // return quota, nil
return quota, err //}
} //
quota, err = strconv.Atoi(quotaString) //func CacheUpdateUserQuota(id int) error {
return quota, err // if !common.RedisEnabled {
} // return nil
// }
func CacheUpdateUserQuota(id int) error { // quota, err := GetUserQuota(id)
if !common.RedisEnabled { // if err != nil {
return nil // return err
} // }
quota, err := GetUserQuota(id) // return cacheSetUserQuota(id, quota)
if err != nil { //}
return err //
} //func cacheSetUserQuota(id int, quota int) error {
return cacheSetUserQuota(id, quota) // err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
} // return err
//}
func cacheSetUserQuota(id int, quota int) error { //
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) //func CacheDecreaseUserQuota(id int, quota int) error {
return err // if !common.RedisEnabled {
} // return nil
// }
func CacheDecreaseUserQuota(id int, quota int) error { // err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
if !common.RedisEnabled { // return err
return nil //}
} //
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) //func CacheIsUserEnabled(userId int) (bool, error) {
return err // if !common.RedisEnabled {
} // return IsUserEnabled(userId)
// }
func CacheIsUserEnabled(userId int) (bool, error) { // enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if !common.RedisEnabled { // if err == nil {
return IsUserEnabled(userId) // return enabled == "1", nil
} // }
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) //
if err == nil { // userEnabled, err := IsUserEnabled(userId)
return enabled == "1", nil // if err != nil {
} // return false, err
// }
userEnabled, err := IsUserEnabled(userId) // enabled = "0"
if err != nil { // if userEnabled {
return false, err // enabled = "1"
} // }
enabled = "0" // err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
if userEnabled { // if err != nil {
enabled = "1" // common.SysError("Redis set user enabled error: " + err.Error())
} // }
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) // return userEnabled, err
if err != nil { //}
common.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
@@ -344,12 +333,12 @@ func CacheGetChannel(id int) (*Channel, error) {
} }
func CacheUpdateChannelStatus(id int, status int) { func CacheUpdateChannelStatus(id int, status int) {
if (!common.MemoryCacheEnabled) { if !common.MemoryCacheEnabled {
return return
} }
channelSyncLock.Lock() channelSyncLock.Lock()
defer channelSyncLock.Unlock() defer channelSyncLock.Unlock()
if channel, ok := channelsIDM[id]; ok { if channel, ok := channelsIDM[id]; ok {
channel.Status = status channel.Status = status
} }
} }

View File

@@ -81,7 +81,7 @@ func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled { if logType == LogTypeConsume && !common.LogConsumeEnabled {
return return
} }
username, _ := CacheGetUsername(userId) username, _ := GetUsernameById(userId, false)
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: username, Username: username,
@@ -102,7 +102,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
username, _ := CacheGetUsername(userId) username, _ := GetUsernameById(userId, false)
otherStr := common.MapToJsonStr(other) otherStr := common.MapToJsonStr(other)
log := &Log{ log := &Log{
UserId: userId, UserId: userId,

View File

@@ -258,37 +258,29 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err return err
} }
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) { func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 { if quota < 0 {
return 0, errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if !relayInfo.IsPlayground { if !relayInfo.IsPlayground {
token, err := GetTokenById(relayInfo.TokenId) token, err := GetTokenById(relayInfo.TokenId)
if err != nil { if err != nil {
return 0, err return err
} }
if !token.UnlimitedQuota && token.RemainQuota < quota { if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足") return errors.New("令牌额度不足")
} }
} }
userQuota, err = GetUserQuota(relayInfo.UserId)
if err != nil {
return 0, err
}
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
if !relayInfo.IsPlayground { if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(relayInfo.TokenId, quota) err := DecreaseTokenQuota(relayInfo.TokenId, quota)
if err != nil { if err != nil {
return 0, err return err
} }
} }
err = DecreaseUserQuota(relayInfo.UserId, quota) return nil
return userQuota - quota, err
} }
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota) err = DecreaseUserQuota(relayInfo.UserId, quota)

View File

@@ -6,7 +6,8 @@ import (
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -107,7 +108,7 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
return users, err return users, err
} }
} }
err = nil err = nil
query := DB.Unscoped().Omit("password") query := DB.Unscoped().Omit("password")
@@ -251,14 +252,12 @@ func (user *User) Update(updatePassword bool) error {
} }
newUser := *user newUser := *user
DB.First(&user, user.Id) DB.First(&user, user.Id)
err = DB.Model(user).Updates(newUser).Error if err = DB.Model(user).Updates(newUser).Error; err != nil {
if err == nil { return err
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
} }
return err
// 更新缓存
return updateUserCache(user)
} }
func (user *User) Edit(updatePassword bool) error { func (user *User) Edit(updatePassword bool) error {
@@ -269,6 +268,7 @@ func (user *User) Edit(updatePassword bool) error {
return err return err
} }
} }
newUser := *user newUser := *user
updates := map[string]interface{}{ updates := map[string]interface{}{
"username": newUser.Username, "username": newUser.Username,
@@ -279,23 +279,26 @@ func (user *User) Edit(updatePassword bool) error {
if updatePassword { if updatePassword {
updates["password"] = newUser.Password updates["password"] = newUser.Password
} }
DB.First(&user, user.Id) DB.First(&user, user.Id)
err = DB.Model(user).Updates(updates).Error if err = DB.Model(user).Updates(updates).Error; err != nil {
if err == nil { return err
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
} }
return err
// 更新缓存
return updateUserCache(user)
} }
func (user *User) Delete() error { func (user *User) Delete() error {
if user.Id == 0 { if user.Id == 0 {
return errors.New("id 为空!") return errors.New("id 为空!")
} }
err := DB.Delete(user).Error if err := DB.Delete(user).Error; err != nil {
return err return err
}
// 清除缓存
return invalidateUserCache(user.Id)
} }
func (user *User) HardDelete() error { func (user *User) HardDelete() error {
@@ -409,15 +412,33 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
func IsUserEnabled(userId int) (bool, error) { // IsUserEnabled checks user status from Redis first, falls back to DB if needed
if userId == 0 { func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
return false, errors.New("user id is empty") defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
} }
var user User var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil { if err != nil {
return false, err return false, err
} }
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled, nil
} }
@@ -433,14 +454,33 @@ func ValidateAccessToken(token string) (user *User) {
return nil return nil
} }
func GetUserQuota(id int) (quota int, err error) { // GetUserQuota gets quota from Redis first, falls back to DB if needed
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
common.SysError("failed to update user quota cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
quota, err := getUserQuotaCache(id)
if err == nil {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
if err != nil { if err != nil {
if common.RedisEnabled { return 0, err
go cacheSetUserQuota(id, quota)
}
} }
return quota, err
return quota, nil
} }
func GetUserUsedQuota(id int) (quota int, err error) { func GetUserUsedQuota(id int) (quota int, err error) {
@@ -453,20 +493,49 @@ func GetUserEmail(id int) (email string, err error) {
return email, err return email, err
} }
func GetUserGroup(id int) (group string, err error) { // GetUserGroup gets group from Redis first, falls back to DB if needed
func GetUserGroup(id int, fromDB bool) (group string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
common.SysError("failed to update user group cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
group, err := getUserGroupCache(id)
if err == nil {
return group, nil
}
// Don't return error - fall through to DB
}
groupCol := "`group`" groupCol := "`group`"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
groupCol = `"group"` groupCol = `"group"`
} }
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err if err != nil {
return "", err
}
return group, nil
} }
func IncreaseUserQuota(id int, quota int) (err error) { func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
gopool.Go(func() {
err := cacheIncrUserQuota(id, quota)
if err != nil {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled { if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
@@ -476,6 +545,9 @@ func IncreaseUserQuota(id int, quota int) (err error) {
func increaseUserQuota(id int, quota int) (err error) { func increaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
if err != nil {
return err
}
return err return err
} }
@@ -483,6 +555,12 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
gopool.Go(func() {
err := cacheDecrUserQuota(id, quota)
if err != nil {
common.SysError("failed to decrease user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled { if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota) addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil return nil
@@ -492,9 +570,23 @@ func DecreaseUserQuota(id int, quota int) (err error) {
func decreaseUserQuota(id int, quota int) (err error) { func decreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
if err != nil {
return err
}
return err return err
} }
func DeltaUpdateUserQuota(id int, delta int) (err error) {
if delta == 0 {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
} else {
return DecreaseUserQuota(id, -delta)
}
}
func GetRootUserEmail() (email string) { func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email return email
@@ -518,7 +610,13 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
).Error ).Error
if err != nil { if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error()) common.SysError("failed to update user used quota and request count: " + err.Error())
return
} }
//// 更新缓存
//if err := invalidateUserCache(id); err != nil {
// common.SysError("failed to invalidate user cache: " + err.Error())
//}
} }
func updateUserUsedQuota(id int, quota int) { func updateUserUsedQuota(id int, quota int) {
@@ -539,9 +637,32 @@ func updateUserRequestCount(id int, count int) {
} }
} }
func GetUsernameById(id int) (username string, err error) { // GetUsernameById gets username from Redis first, falls back to DB if needed
func GetUsernameById(id int, fromDB bool) (username string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
common.SysError("failed to update user name cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
username, err := getUserNameCache(id)
if err == nil {
return username, nil
}
// Don't return error - fall through to DB
}
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
return username, err if err != nil {
return "", err
}
return username, nil
} }
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {

206
model/user_cache.go Normal file
View File

@@ -0,0 +1,206 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"strconv"
"time"
)
// Change UserCache struct to userCache
type userCache struct {
Id int `json:"id"`
Group string `json:"group"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
}
}
return nil
}
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(user *User) error {
if !common.RedisEnabled {
return nil
}
if err := updateUserGroupCache(user.Id, user.Group); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(user.Id, user.Quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(user.Id, user.Username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
if err != nil {
return 0, err
}
return strconv.Atoi(quotaStr)
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
if err != nil {
return 0, err
}
return strconv.Atoi(statusStr)
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
func cacheIncrUserQuota(userId int, delta int) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
}
func cacheDecrUserQuota(userId int, delta int) error {
return cacheIncrUserQuota(userId, -delta)
}

View File

@@ -77,24 +77,20 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
groupRatio := setting.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota // in this case, we do not pre-consume quota
// because the user has enough quota // because the user has enough quota
preConsumedQuota = 0 preConsumedQuota = 0
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }

View File

@@ -100,7 +100,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
} }
groupRatio := setting.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0 sizeRatio := 1.0
// Size // Size

View File

@@ -170,7 +170,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
} }
groupRatio := setting.GetGroupRatio(group) groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.GetUserQuota(userId, false)
if err != nil { if err != nil {
return &dto.MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
@@ -194,11 +194,11 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
err = model.CacheUpdateUserQuota(userId) //err = model.CacheUpdateUserQuota(userId)
if err != nil { if err != nil {
common.SysError("error update user quota cache: " + err.Error()) common.SysError("error update user quota cache: " + err.Error())
} }
@@ -476,7 +476,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
} }
groupRatio := setting.GetGroupRatio(group) groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.GetUserQuota(userId, false)
if err != nil { if err != nil {
return &dto.MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
@@ -500,14 +500,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result) logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)

View File

@@ -262,7 +262,7 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
// 预扣费并返回用户剩余配额 // 预扣费并返回用户剩余配额
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
@@ -272,10 +272,6 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足 // 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited { if !relayInfo.TokenUnlimited {
@@ -293,8 +289,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
} }
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
@@ -307,7 +308,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
go func() { go func() {
relayInfoCopy := *relayInfo relayInfoCopy := *relayInfo
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.SysError("error return pre-consumed quota: " + err.Error())
} }
@@ -365,15 +366,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//} //}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
} }
err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }

View File

@@ -51,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// 预扣 // 预扣
groupRatio := setting.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return return
@@ -113,14 +113,10 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota // release quota
if relayInfo.ConsumeQuota && taskErr == nil { if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)

View File

@@ -18,7 +18,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
if relayInfo.UsePrice { if relayInfo.UsePrice {
return nil return nil
} }
userQuota, err := model.GetUserQuota(relayInfo.UserId) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return err return err
} }
@@ -58,15 +58,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
} }
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
if err != nil { if err != nil {
return err return err
} }
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
return err
}
return nil return nil
} }
@@ -120,7 +116,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
//} //}
//quotaDelta := quota - preConsumedQuota //quotaDelta := quota - preConsumedQuota
//if quotaDelta != 0 { //if quotaDelta != 0 {
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) // err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// if err != nil { // if err != nil {
// common.LogError(ctx, "error consuming token remain quota: "+err.Error()) // common.LogError(ctx, "error consuming token remain quota: "+err.Error())
// } // }
@@ -190,15 +186,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else { } else {
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
} }
err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
} }