refactor: token cache logic

This commit is contained in:
CalciumIon
2024-12-30 17:10:48 +08:00
parent ca8b7ed1c3
commit bb5e032dd2
15 changed files with 417 additions and 196 deletions

View File

@@ -1,99 +1,16 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"sort"
"strings"
"sync"
"time"
)
// 仅用于定时同步缓存
var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error {
jsonBytes, err := json.Marshal(token)
if err != nil {
return err
}
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
return err
}
token2UserIdLock.Lock()
defer token2UserIdLock.Unlock()
token2UserId[token.Key] = token.UserId
return nil
}
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
func CacheGetTokenByKey(key string) (*Token, error) {
if !common.RedisEnabled {
return GetTokenByKey(key)
}
var token *Token
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果缓存中不存在,则从数据库中获取
token, err = GetTokenByKey(key)
if err != nil {
return nil, err
}
err = cacheSetToken(token)
return token, nil
}
// 如果缓存中存在,则续期时间
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second)
err = json.Unmarshal([]byte(tokenObjectString), &token)
return token, err
}
func SyncTokenCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing tokens from database")
token2UserIdLock.Lock()
// 从token2UserId中获取所有的key
var copyToken2UserId = make(map[string]int)
for s, i := range token2UserId {
copyToken2UserId[s] = i
}
token2UserId = make(map[string]int)
token2UserIdLock.Unlock()
for key := range copyToken2UserId {
token, err := GetTokenByKey(key)
if err != nil {
// 如果数据库中不存在,则删除缓存
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
//delete redis
err := common.RedisDel(fmt.Sprintf("token:%s", key))
if err != nil {
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
}
} else {
// 如果数据库中存在先检查redis
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果redis中不存在则跳过
continue
}
err = cacheSetToken(token)
if err != nil {
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
}
}
}
}
}
//func CacheGetUserGroup(id int) (group string, err error) {
// if !common.RedisEnabled {
// return GetUserGroup(id)

View File

@@ -12,16 +12,6 @@ import (
"gorm.io/gorm"
)
var groupCol string
func init() {
if common.UsingPostgreSQL {
groupCol = `"group"`
} else {
groupCol = "`group`"
}
}
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`

View File

@@ -13,6 +13,20 @@ import (
"time"
)
var groupCol string
var keyCol string
func init() {
if common.UsingPostgreSQL {
groupCol = `"group"`
keyCol = `"key"`
} else {
groupCol = "`group`"
keyCol = "`key`"
}
}
var DB *gorm.DB
var LOG_DB *gorm.DB

View File

@@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
@@ -30,6 +31,10 @@ type Token struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (token *Token) Clean() {
token.Key = ""
}
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供令牌")
}
token, err = CacheGetTokenByKey(key)
token, err = GetTokenByKey(key, false)
if err == nil {
if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3]
@@ -129,21 +134,37 @@ func GetTokenById(id int) (*Token, error) {
var err error = nil
err = DB.First(&token, "id = ?", id).Error
if err != nil {
if common.RedisEnabled {
go cacheSetToken(&token)
}
gopool.Go(func() {
if err := cacheSetToken(token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
return &token, err
}
func GetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() {
if err := cacheSetToken(*token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
token, err := cacheGetTokenByKey(key)
if err == nil {
return token, nil
}
// Don't return error - fall through to DB
}
var token Token
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
fromDB = true
err = DB.Where(keyCol+" = ?", key).First(&token).Error
return token, err
}
func (token *Token) Insert() error {
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
func (token *Token) Update() (err error) {
defer func() {
if common.RedisEnabled && err == nil {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
}
})
}
}()
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err
}
func (token *Token) SelectUpdate() error {
func (token *Token) SelectUpdate() (err error) {
defer func() {
if common.RedisEnabled && err == nil {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
}
})
}
}()
// This can update zero values
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
}
func (token *Token) Delete() error {
var err error
func (token *Token) Delete() (err error) {
defer func() {
if common.RedisEnabled && err == nil {
gopool.Go(func() {
err := cacheDeleteToken(token.Key)
if err != nil {
common.SysError("failed to delete token cache: " + err.Error())
}
})
}
}()
err = DB.Delete(token).Error
return err
}
@@ -214,10 +263,16 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func IncreaseTokenQuota(id int, quota int) (err error) {
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to increase token quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
@@ -236,10 +291,16 @@ func increaseTokenQuota(id int, quota int) (err error) {
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to decrease token quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
@@ -262,20 +323,22 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if !relayInfo.IsPlayground {
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
if relayInfo.IsPlayground {
return nil
}
if !relayInfo.IsPlayground {
err := DecreaseTokenQuota(relayInfo.TokenId, quota)
if err != nil {
return err
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
@@ -293,9 +356,9 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err

64
model/token_cache.go Normal file
View File

@@ -0,0 +1,64 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"time"
)
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
if err != nil {
return err
}
return nil
}
func cacheDeleteToken(key string) error {
key = common.GenerateHMAC(key)
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
if err != nil {
return err
}
return nil
}
func cacheIncrTokenQuota(key string, increment int64) error {
key = common.GenerateHMAC(key)
err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
if err != nil {
return err
}
return nil
}
func cacheDecrTokenQuota(key string, decrement int64) error {
return cacheIncrTokenQuota(key, -decrement)
}
func cacheSetTokenField(key string, field string, value string) error {
key = common.GenerateHMAC(key)
err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
if err != nil {
return err
}
return nil
}
// CacheGetTokenByKey 从缓存中获取 token如果缓存中不存在则从数据库中获取
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
if err != nil {
return nil, err
}
token.Key = key
return &token, nil
}

View File

@@ -252,7 +252,7 @@ func (user *User) Update(updatePassword bool) error {
}
// 更新缓存
return updateUserCache(user)
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
}
func (user *User) Edit(updatePassword bool) error {
@@ -281,7 +281,7 @@ func (user *User) Edit(updatePassword bool) error {
}
// 更新缓存
return updateUserCache(user)
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
}
func (user *User) Delete() error {
@@ -411,7 +411,7 @@ func IsAdmin(userId int) bool {
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
@@ -427,7 +427,7 @@ func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
@@ -453,7 +453,7 @@ func ValidateAccessToken(token string) (user *User) {
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
common.SysError("failed to update user quota cache: " + err.Error())
@@ -469,7 +469,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
if err != nil {
return 0, err
@@ -492,7 +492,7 @@ func GetUserEmail(id int) (email string, err error) {
func GetUserGroup(id int, fromDB bool) (group string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
common.SysError("failed to update user group cache: " + err.Error())
@@ -507,7 +507,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
if err != nil {
return "", err
@@ -521,7 +521,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheIncrUserQuota(id, quota)
err := cacheIncrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to increase user quota: " + err.Error())
}
@@ -546,7 +546,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheDecrUserQuota(id, quota)
err := cacheDecrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to decrease user quota: " + err.Error())
}
@@ -631,7 +631,7 @@ func updateUserRequestCount(id int, count int) {
func GetUsernameById(id int, fromDB bool) (username string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if common.RedisEnabled && err == nil {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
common.SysError("failed to update user name cache: " + err.Error())
@@ -646,7 +646,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
if err != nil {
return "", err

View File

@@ -93,24 +93,24 @@ func updateUserNameCache(userId int, username string) error {
}
// updateUserCache updates all user cache fields
func updateUserCache(user *User) error {
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
}
if err := updateUserGroupCache(user.Id, user.Group); err != nil {
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(user.Id, user.Quota); err != nil {
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil {
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(user.Id, user.Username); err != nil {
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
@@ -193,7 +193,7 @@ func getUserCache(userId int) (*userCache, error) {
}
// Add atomic quota operations
func cacheIncrUserQuota(userId int, delta int) error {
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
@@ -201,6 +201,6 @@ func cacheIncrUserQuota(userId int, delta int) error {
return common.RedisIncr(key, delta)
}
func cacheDecrUserQuota(userId int, delta int) error {
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}

View File

@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
}
return false, err
}
func shouldUpdateRedis(fromDB bool, err error) bool {
return common.RedisEnabled && fromDB && err == nil
}