diff --git a/common/redis.go b/common/redis.go index 6a352356..cc8035af 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,9 +2,11 @@ package common import ( "context" - "github.com/go-redis/redis/v8" + "fmt" "os" "time" + + "github.com/go-redis/redis/v8" ) var RDB *redis.Client @@ -104,3 +106,21 @@ func RedisDecrease(key string, value int64) error { } return nil } + +// RedisIncr Add this function to handle atomic increments +func RedisIncr(key string, delta int) error { + ctx := context.Background() + + // 检查键是否存在 + exists, err := RDB.Exists(ctx, key).Result() + if err != nil { + return err + } + if exists == 0 { + return fmt.Errorf("key does not exist") // 键不存在,返回错误 + } + + // 键存在,执行INCRBY操作 + result := RDB.IncrBy(ctx, key, int64(delta)) + return result.Err() +} diff --git a/constant/cache_key.go b/constant/cache_key.go new file mode 100644 index 00000000..d5a2c5ac --- /dev/null +++ b/constant/cache_key.go @@ -0,0 +1,18 @@ +package constant + +import "one-api/common" + +var ( + TokenCacheSeconds = common.SyncFrequency + UserId2GroupCacheSeconds = common.SyncFrequency + UserId2QuotaCacheSeconds = common.SyncFrequency + UserId2StatusCacheSeconds = common.SyncFrequency +) + +const ( + // Cache keys + UserGroupKeyFmt = "user_group:%d" + UserQuotaKeyFmt = "user_quota:%d" + UserEnabledKeyFmt = "user_enabled:%d" + UserUsernameKeyFmt = "user_name:%d" +) diff --git a/controller/billing.go b/controller/billing.go index 02fb8bd2..1fb83633 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -21,7 +21,7 @@ func GetSubscription(c *gin.Context) { usedQuota = token.UsedQuota } else { userId := c.GetInt("id") - remainQuota, err = model.GetUserQuota(userId) + remainQuota, err = model.GetUserQuota(userId, false) usedQuota, err = model.GetUserUsedQuota(userId) } if expiredTime <= 0 { diff --git a/controller/group.go b/controller/group.go index c5fde769..b700fc96 100644 --- a/controller/group.go +++ b/controller/group.go @@ -23,7 +23,7 @@ func GetUserGroups(c *gin.Context) { usableGroups := make(map[string]string) userGroup := "" userId := c.GetInt("id") - userGroup, _ = model.CacheGetUserGroup(userId) + userGroup, _ = model.GetUserGroup(userId, false) for groupName, _ := range setting.GetGroupRatioCopy() { // UserUsableGroups contains the groups that the user can use userUsableGroups := setting.GetUserUsableGroups(userGroup) diff --git a/controller/model.go b/controller/model.go index 3d207023..8ec2c7c9 100644 --- a/controller/model.go +++ b/controller/model.go @@ -166,7 +166,7 @@ func ListModels(c *gin.Context) { } } else { userId := c.GetInt("id") - userGroup, err := model.GetUserGroup(userId) + userGroup, err := model.GetUserGroup(userId, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/task.go b/controller/task.go index fce9e7f0..928f7ed7 100644 --- a/controller/task.go +++ b/controller/task.go @@ -153,7 +153,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) + //err = model.CacheUpdateUserQuota(task.UserId) ? if err != nil { common.LogError(ctx, "error update user quota cache: "+err.Error()) } else { diff --git a/controller/topup.go b/controller/topup.go index 85a12a1a..fb51c545 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -75,7 +75,7 @@ func RequestEpay(c *gin.Context) { } id := c.GetInt("id") - group, err := model.CacheGetUserGroup(id) + group, err := model.GetUserGroup(id, true) if err != nil { c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) return @@ -236,7 +236,7 @@ func RequestAmount(c *gin.Context) { return } id := c.GetInt("id") - group, err := model.CacheGetUserGroup(id) + group, err := model.GetUserGroup(id, true) if err != nil { c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) return diff --git a/middleware/auth.go b/middleware/auth.go index cb55cac9..64d1895c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -201,7 +201,7 @@ func TokenAuth() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } - userEnabled, err := model.CacheIsUserEnabled(token.UserId) + userEnabled, err := model.IsUserEnabled(token.UserId, false) if err != nil { abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/middleware/distributor.go b/middleware/distributor.go index 0d5a8cac..49cca260 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -40,7 +40,7 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup, _ := model.CacheGetUserGroup(userId) + userGroup, _ := model.GetUserGroup(userId, false) tokenGroup := c.GetString("token_group") if tokenGroup != "" { // check common.UserUsableGroups[userGroup] diff --git a/model/cache.go b/model/cache.go index a4ef47cd..0d87d1e1 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,20 +6,13 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/constant" "sort" - "strconv" "strings" "sync" "time" ) -var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency -) - // 仅用于定时同步缓存 var token2UserId = make(map[string]int) var token2UserIdLock sync.RWMutex @@ -29,7 +22,7 @@ func cacheSetToken(token *Token) error { if err != nil { return err } - err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) + err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second) if err != nil { common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error())) return err @@ -57,7 +50,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { return token, nil } // 如果缓存中存在,则续期时间 - err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second) + err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second) err = json.Unmarshal([]byte(tokenObjectString), &token) return token, err } @@ -101,109 +94,105 @@ func SyncTokenCache(frequency int) { } } -func CacheGetUserGroup(id int) (group string, err error) { - if !common.RedisEnabled { - return GetUserGroup(id) - } - group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) - if err != nil { - group, err = GetUserGroup(id) - if err != nil { - return "", err - } - err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user group error: " + err.Error()) - } - } - return group, err -} - -func CacheGetUsername(id int) (username string, err error) { - if !common.RedisEnabled { - return GetUsernameById(id) - } - username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id)) - if err != nil { - username, err = GetUsernameById(id) - if err != nil { - return "", err - } - err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user group error: " + err.Error()) - } - } - return username, err -} - -func CacheGetUserQuota(id int) (quota int, err error) { - if !common.RedisEnabled { - return GetUserQuota(id) - } - quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) - if err != nil { - quota, err = GetUserQuota(id) - if err != nil { - return 0, err - } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user quota error: " + err.Error()) - } - return quota, err - } - quota, err = strconv.Atoi(quotaString) - return quota, err -} - -func CacheUpdateUserQuota(id int) error { - if !common.RedisEnabled { - return nil - } - quota, err := GetUserQuota(id) - if err != nil { - return err - } - return cacheSetUserQuota(id, quota) -} - -func cacheSetUserQuota(id int, quota int) error { - err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - return err -} - -func CacheDecreaseUserQuota(id int, quota int) error { - if !common.RedisEnabled { - return nil - } - err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) - return err -} - -func CacheIsUserEnabled(userId int) (bool, error) { - if !common.RedisEnabled { - return IsUserEnabled(userId) - } - enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err == nil { - return enabled == "1", nil - } - - userEnabled, err := IsUserEnabled(userId) - if err != nil { - return false, err - } - enabled = "0" - if userEnabled { - enabled = "1" - } - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } - return userEnabled, err -} +//func CacheGetUserGroup(id int) (group string, err error) { +// if !common.RedisEnabled { +// return GetUserGroup(id) +// } +// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) +// if err != nil { +// group, err = GetUserGroup(id) +// if err != nil { +// return "", err +// } +// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second) +// if err != nil { +// common.SysError("Redis set user group error: " + err.Error()) +// } +// } +// return group, err +//} +// +//func CacheGetUsername(id int) (username string, err error) { +// if !common.RedisEnabled { +// return GetUsernameById(id) +// } +// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id)) +// if err != nil { +// username, err = GetUsernameById(id) +// if err != nil { +// return "", err +// } +// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second) +// if err != nil { +// common.SysError("Redis set user group error: " + err.Error()) +// } +// } +// return username, err +//} +// +//func CacheGetUserQuota(id int) (quota int, err error) { +// if !common.RedisEnabled { +// return GetUserQuota(id) +// } +// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) +// if err != nil { +// quota, err = GetUserQuota(id) +// if err != nil { +// return 0, err +// } +// return quota, nil +// } +// quota, err = strconv.Atoi(quotaString) +// return quota, nil +//} +// +//func CacheUpdateUserQuota(id int) error { +// if !common.RedisEnabled { +// return nil +// } +// quota, err := GetUserQuota(id) +// if err != nil { +// return err +// } +// return cacheSetUserQuota(id, quota) +//} +// +//func cacheSetUserQuota(id int, quota int) error { +// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second) +// return err +//} +// +//func CacheDecreaseUserQuota(id int, quota int) error { +// if !common.RedisEnabled { +// return nil +// } +// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) +// return err +//} +// +//func CacheIsUserEnabled(userId int) (bool, error) { +// if !common.RedisEnabled { +// return IsUserEnabled(userId) +// } +// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) +// if err == nil { +// return enabled == "1", nil +// } +// +// userEnabled, err := IsUserEnabled(userId) +// if err != nil { +// return false, err +// } +// enabled = "0" +// if userEnabled { +// enabled = "1" +// } +// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second) +// if err != nil { +// common.SysError("Redis set user enabled error: " + err.Error()) +// } +// return userEnabled, err +//} var group2model2channels map[string]map[string][]*Channel var channelsIDM map[int]*Channel @@ -344,12 +333,12 @@ func CacheGetChannel(id int) (*Channel, error) { } func CacheUpdateChannelStatus(id int, status int) { - if (!common.MemoryCacheEnabled) { - return - } - channelSyncLock.Lock() - defer channelSyncLock.Unlock() - if channel, ok := channelsIDM[id]; ok { - channel.Status = status - } + if !common.MemoryCacheEnabled { + return + } + channelSyncLock.Lock() + defer channelSyncLock.Unlock() + if channel, ok := channelsIDM[id]; ok { + channel.Status = status + } } diff --git a/model/log.go b/model/log.go index 56cb54cb..06abeafa 100644 --- a/model/log.go +++ b/model/log.go @@ -81,7 +81,7 @@ func RecordLog(userId int, logType int, content string) { if logType == LogTypeConsume && !common.LogConsumeEnabled { return } - username, _ := CacheGetUsername(userId) + username, _ := GetUsernameById(userId, false) log := &Log{ UserId: userId, Username: username, @@ -102,7 +102,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke if !common.LogConsumeEnabled { return } - username, _ := CacheGetUsername(userId) + username, _ := GetUsernameById(userId, false) otherStr := common.MapToJsonStr(other) log := &Log{ UserId: userId, diff --git a/model/token.go b/model/token.go index 0f5a87cc..4d52bf03 100644 --- a/model/token.go +++ b/model/token.go @@ -258,37 +258,29 @@ func decreaseTokenQuota(id int, quota int) (err error) { return err } -func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) { +func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { if quota < 0 { - return 0, errors.New("quota 不能为负数!") + return errors.New("quota 不能为负数!") } if !relayInfo.IsPlayground { token, err := GetTokenById(relayInfo.TokenId) if err != nil { - return 0, err + return err } if !token.UnlimitedQuota && token.RemainQuota < quota { - return 0, errors.New("令牌额度不足") + return errors.New("令牌额度不足") } } - userQuota, err = GetUserQuota(relayInfo.UserId) - if err != nil { - return 0, err - } - if userQuota < quota { - return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) - } if !relayInfo.IsPlayground { - err = DecreaseTokenQuota(relayInfo.TokenId, quota) + err := DecreaseTokenQuota(relayInfo.TokenId, quota) if err != nil { - return 0, err + return err } } - err = DecreaseUserQuota(relayInfo.UserId, quota) - return userQuota - quota, err + return nil } -func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { +func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { if quota > 0 { err = DecreaseUserQuota(relayInfo.UserId, quota) diff --git a/model/user.go b/model/user.go index 36f18005..ff95b173 100644 --- a/model/user.go +++ b/model/user.go @@ -6,7 +6,8 @@ import ( "one-api/common" "strconv" "strings" - "time" + + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" ) @@ -107,7 +108,7 @@ func SearchUsers(keyword string, group string) ([]*User, error) { return users, err } } - + err = nil query := DB.Unscoped().Omit("password") @@ -251,14 +252,12 @@ func (user *User) Update(updatePassword bool) error { } newUser := *user DB.First(&user, user.Id) - err = DB.Model(user).Updates(newUser).Error - if err == nil { - if common.RedisEnabled { - _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) - _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - } + if err = DB.Model(user).Updates(newUser).Error; err != nil { + return err } - return err + + // 更新缓存 + return updateUserCache(user) } func (user *User) Edit(updatePassword bool) error { @@ -269,6 +268,7 @@ func (user *User) Edit(updatePassword bool) error { return err } } + newUser := *user updates := map[string]interface{}{ "username": newUser.Username, @@ -279,23 +279,26 @@ func (user *User) Edit(updatePassword bool) error { if updatePassword { updates["password"] = newUser.Password } + DB.First(&user, user.Id) - err = DB.Model(user).Updates(updates).Error - if err == nil { - if common.RedisEnabled { - _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) - _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - } + if err = DB.Model(user).Updates(updates).Error; err != nil { + return err } - return err + + // 更新缓存 + return updateUserCache(user) } func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") } - err := DB.Delete(user).Error - return err + if err := DB.Delete(user).Error; err != nil { + return err + } + + // 清除缓存 + return invalidateUserCache(user.Id) } func (user *User) HardDelete() error { @@ -409,15 +412,33 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) (bool, error) { - if userId == 0 { - return false, errors.New("user id is empty") +// IsUserEnabled checks user status from Redis first, falls back to DB if needed +func IsUserEnabled(id int, fromDB bool) (status bool, err error) { + defer func() { + // Update Redis cache asynchronously on successful DB read + if common.RedisEnabled { + gopool.Go(func() { + if err := updateUserStatusCache(id, status); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + // Try Redis first + status, err := getUserStatusCache(id) + if err == nil { + return status == common.UserStatusEnabled, nil + } + // Don't return error - fall through to DB } + var user User - err := DB.Where("id = ?", userId).Select("status").Find(&user).Error + err = DB.Where("id = ?", id).Select("status").Find(&user).Error if err != nil { return false, err } + return user.Status == common.UserStatusEnabled, nil } @@ -433,14 +454,33 @@ func ValidateAccessToken(token string) (user *User) { return nil } -func GetUserQuota(id int) (quota int, err error) { +// GetUserQuota gets quota from Redis first, falls back to DB if needed +func GetUserQuota(id int, fromDB bool) (quota int, err error) { + defer func() { + // Update Redis cache asynchronously on successful DB read + if common.RedisEnabled && err == nil { + gopool.Go(func() { + if err := updateUserQuotaCache(id, quota); err != nil { + common.SysError("failed to update user quota cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + quota, err := getUserQuotaCache(id) + if err == nil { + return quota, nil + } + // Don't return error - fall through to DB + //common.SysError("failed to get user quota from cache: " + err.Error()) + } + err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error if err != nil { - if common.RedisEnabled { - go cacheSetUserQuota(id, quota) - } + return 0, err } - return quota, err + + return quota, nil } func GetUserUsedQuota(id int) (quota int, err error) { @@ -453,20 +493,49 @@ func GetUserEmail(id int) (email string, err error) { return email, err } -func GetUserGroup(id int) (group string, err error) { +// GetUserGroup gets group from Redis first, falls back to DB if needed +func GetUserGroup(id int, fromDB bool) (group string, err error) { + defer func() { + // Update Redis cache asynchronously on successful DB read + if common.RedisEnabled && err == nil { + gopool.Go(func() { + if err := updateUserGroupCache(id, group); err != nil { + common.SysError("failed to update user group cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + group, err := getUserGroupCache(id) + if err == nil { + return group, nil + } + // Don't return error - fall through to DB + } + groupCol := "`group`" if common.UsingPostgreSQL { groupCol = `"group"` } err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error - return group, err + if err != nil { + return "", err + } + + return group, nil } func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + gopool.Go(func() { + err := cacheIncrUserQuota(id, quota) + if err != nil { + common.SysError("failed to increase user quota: " + err.Error()) + } + }) if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, quota) return nil @@ -476,6 +545,9 @@ func IncreaseUserQuota(id int, quota int) (err error) { func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error + if err != nil { + return err + } return err } @@ -483,6 +555,12 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + gopool.Go(func() { + err := cacheDecrUserQuota(id, quota) + if err != nil { + common.SysError("failed to decrease user quota: " + err.Error()) + } + }) if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, -quota) return nil @@ -492,9 +570,23 @@ func DecreaseUserQuota(id int, quota int) (err error) { func decreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error + if err != nil { + return err + } return err } +func DeltaUpdateUserQuota(id int, delta int) (err error) { + if delta == 0 { + return nil + } + if delta > 0 { + return IncreaseUserQuota(id, delta) + } else { + return DecreaseUserQuota(id, -delta) + } +} + func GetRootUserEmail() (email string) { DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) return email @@ -518,7 +610,13 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { ).Error if err != nil { common.SysError("failed to update user used quota and request count: " + err.Error()) + return } + + //// 更新缓存 + //if err := invalidateUserCache(id); err != nil { + // common.SysError("failed to invalidate user cache: " + err.Error()) + //} } func updateUserUsedQuota(id int, quota int) { @@ -539,9 +637,32 @@ func updateUserRequestCount(id int, count int) { } } -func GetUsernameById(id int) (username string, err error) { +// GetUsernameById gets username from Redis first, falls back to DB if needed +func GetUsernameById(id int, fromDB bool) (username string, err error) { + defer func() { + // Update Redis cache asynchronously on successful DB read + if common.RedisEnabled && err == nil { + gopool.Go(func() { + if err := updateUserNameCache(id, username); err != nil { + common.SysError("failed to update user name cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + username, err := getUserNameCache(id) + if err == nil { + return username, nil + } + // Don't return error - fall through to DB + } + err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error - return username, err + if err != nil { + return "", err + } + + return username, nil } func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { diff --git a/model/user_cache.go b/model/user_cache.go new file mode 100644 index 00000000..8c112939 --- /dev/null +++ b/model/user_cache.go @@ -0,0 +1,206 @@ +package model + +import ( + "fmt" + "one-api/common" + "one-api/constant" + "strconv" + "time" +) + +// Change UserCache struct to userCache +type userCache struct { + Id int `json:"id"` + Group string `json:"group"` + Quota int `json:"quota"` + Status int `json:"status"` + Role int `json:"role"` + Username string `json:"username"` +} + +// Rename all exported functions to private ones +// invalidateUserCache clears all user related cache +func invalidateUserCache(userId int) error { + if !common.RedisEnabled { + return nil + } + + keys := []string{ + fmt.Sprintf(constant.UserGroupKeyFmt, userId), + fmt.Sprintf(constant.UserQuotaKeyFmt, userId), + fmt.Sprintf(constant.UserEnabledKeyFmt, userId), + fmt.Sprintf(constant.UserUsernameKeyFmt, userId), + } + + for _, key := range keys { + if err := common.RedisDel(key); err != nil { + return fmt.Errorf("failed to delete cache key %s: %w", key, err) + } + } + return nil +} + +// updateUserGroupCache updates user group cache +func updateUserGroupCache(userId int, group string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisSet( + fmt.Sprintf(constant.UserGroupKeyFmt, userId), + group, + time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + ) +} + +// updateUserQuotaCache updates user quota cache +func updateUserQuotaCache(userId int, quota int) error { + if !common.RedisEnabled { + return nil + } + return common.RedisSet( + fmt.Sprintf(constant.UserQuotaKeyFmt, userId), + fmt.Sprintf("%d", quota), + time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + ) +} + +// updateUserStatusCache updates user status cache +func updateUserStatusCache(userId int, userEnabled bool) error { + if !common.RedisEnabled { + return nil + } + enabled := "0" + if userEnabled { + enabled = "1" + } + return common.RedisSet( + fmt.Sprintf(constant.UserEnabledKeyFmt, userId), + enabled, + time.Duration(constant.UserId2StatusCacheSeconds)*time.Second, + ) +} + +// updateUserNameCache updates username cache +func updateUserNameCache(userId int, username string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisSet( + fmt.Sprintf(constant.UserUsernameKeyFmt, userId), + username, + time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + ) +} + +// updateUserCache updates all user cache fields +func updateUserCache(user *User) error { + if !common.RedisEnabled { + return nil + } + + if err := updateUserGroupCache(user.Id, user.Group); err != nil { + return fmt.Errorf("update group cache: %w", err) + } + + if err := updateUserQuotaCache(user.Id, user.Quota); err != nil { + return fmt.Errorf("update quota cache: %w", err) + } + + if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil { + return fmt.Errorf("update status cache: %w", err) + } + + if err := updateUserNameCache(user.Id, user.Username); err != nil { + return fmt.Errorf("update username cache: %w", err) + } + + return nil +} + +// getUserGroupCache gets user group from cache +func getUserGroupCache(userId int) (string, error) { + if !common.RedisEnabled { + return "", nil + } + return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId)) +} + +// getUserQuotaCache gets user quota from cache +func getUserQuotaCache(userId int) (int, error) { + if !common.RedisEnabled { + return 0, nil + } + quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId)) + if err != nil { + return 0, err + } + return strconv.Atoi(quotaStr) +} + +// getUserStatusCache gets user status from cache +func getUserStatusCache(userId int) (int, error) { + if !common.RedisEnabled { + return 0, nil + } + statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId)) + if err != nil { + return 0, err + } + return strconv.Atoi(statusStr) +} + +// getUserNameCache gets username from cache +func getUserNameCache(userId int) (string, error) { + if !common.RedisEnabled { + return "", nil + } + return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId)) +} + +// getUserCache gets complete user cache +func getUserCache(userId int) (*userCache, error) { + if !common.RedisEnabled { + return nil, nil + } + + group, err := getUserGroupCache(userId) + if err != nil { + return nil, fmt.Errorf("get group cache: %w", err) + } + + quota, err := getUserQuotaCache(userId) + if err != nil { + return nil, fmt.Errorf("get quota cache: %w", err) + } + + status, err := getUserStatusCache(userId) + if err != nil { + return nil, fmt.Errorf("get status cache: %w", err) + } + + username, err := getUserNameCache(userId) + if err != nil { + return nil, fmt.Errorf("get username cache: %w", err) + } + + return &userCache{ + Id: userId, + Group: group, + Quota: quota, + Status: status, + Username: username, + }, nil +} + +// Add atomic quota operations +func cacheIncrUserQuota(userId int, delta int) error { + if !common.RedisEnabled { + return nil + } + key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId) + return common.RedisIncr(key, delta) +} + +func cacheDecrUserQuota(userId int, delta int) error { + return cacheIncrUserQuota(userId, -delta) +} diff --git a/relay/relay-audio.go b/relay/relay-audio.go index c9f54f82..a2943457 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -77,24 +77,20 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { groupRatio := setting.GetGroupRatio(relayInfo.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) } - err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 } if preConsumedQuota > 0 { - userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } diff --git a/relay/relay-image.go b/relay/relay-image.go index 5ec71611..207350da 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -100,7 +100,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } groupRatio := setting.GetGroupRatio(relayInfo.Group) - userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) sizeRatio := 1.0 // Size diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 8bc5c93a..0facecab 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -170,7 +170,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } groupRatio := setting.GetGroupRatio(group) ratio := modelPrice * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -194,11 +194,11 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func(ctx context.Context) { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) + err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + //err = model.CacheUpdateUserQuota(userId) if err != nil { common.SysError("error update user quota cache: " + err.Error()) } @@ -476,7 +476,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } groupRatio := setting.GetGroupRatio(group) ratio := modelPrice * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -500,14 +500,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func(ctx context.Context) { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) + err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result) diff --git a/relay/relay-text.go b/relay/relay-text.go index c3e449be..6f251f6d 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -262,7 +262,7 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom // 预扣费并返回用户剩余配额 func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { - userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -272,10 +272,6 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if userQuota-preConsumedQuota < 0 { return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) } - err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 if !relayInfo.TokenUnlimited { @@ -293,8 +289,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) } } + if preConsumedQuota > 0 { - userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } @@ -307,7 +308,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us go func() { relayInfoCopy := *relayInfo - err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) + err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) if err != nil { common.SysError("error return pre-consumed quota: " + err.Error()) } @@ -365,15 +366,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN //} quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } - err := model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 7b694a81..61577faf 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -51,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { // 预扣 groupRatio := setting.GetGroupRatio(relayInfo.Group) ratio := modelPrice * groupRatio - userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return @@ -113,14 +113,10 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { // release quota if relayInfo.ConsumeQuota && taskErr == nil { - err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) + err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) diff --git a/service/quota.go b/service/quota.go index 2e0cd4fb..820dcce5 100644 --- a/service/quota.go +++ b/service/quota.go @@ -18,7 +18,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag if relayInfo.UsePrice { return nil } - userQuota, err := model.GetUserQuota(relayInfo.UserId) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return err } @@ -58,15 +58,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) } - err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) + err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false) if err != nil { return err } common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) - err = model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - return err - } return nil } @@ -120,7 +116,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod //} //quotaDelta := quota - preConsumedQuota //if quotaDelta != 0 { - // err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + // err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) // if err != nil { // common.LogError(ctx, "error consuming token remain quota: "+err.Error()) // } @@ -190,15 +186,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } - err := model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) }