refactor: token cache logic
This commit is contained in:
@@ -1,6 +1,23 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "golang.org/x/crypto/bcrypt"
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateHMACWithKey(key []byte, data string) string {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write([]byte(data))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateHMAC(data string) string {
|
||||||
|
h := hmac.New(sha256.New, []byte(SessionSecret))
|
||||||
|
h.Write([]byte(data))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
func Password2Hash(password string) (string, error) {
|
func Password2Hash(password string) (string, error) {
|
||||||
passwordBytes := []byte(password)
|
passwordBytes := []byte(password)
|
||||||
|
|||||||
216
common/redis.go
216
common/redis.go
@@ -2,11 +2,15 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
@@ -58,39 +62,167 @@ func RedisGet(key string) (string, error) {
|
|||||||
return RDB.Get(ctx, key).Result()
|
return RDB.Get(ctx, key).Result()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RedisExpire(key string, expiration time.Duration) error {
|
//func RedisExpire(key string, expiration time.Duration) error {
|
||||||
ctx := context.Background()
|
// ctx := context.Background()
|
||||||
return RDB.Expire(ctx, key, expiration).Err()
|
// return RDB.Expire(ctx, key, expiration).Err()
|
||||||
}
|
//}
|
||||||
|
//
|
||||||
func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||||
ctx := context.Background()
|
// ctx := context.Background()
|
||||||
return RDB.GetSet(ctx, key, expiration).Result()
|
// return RDB.GetSet(ctx, key, expiration).Result()
|
||||||
}
|
//}
|
||||||
|
|
||||||
func RedisDel(key string) error {
|
func RedisDel(key string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return RDB.Del(ctx, key).Err()
|
return RDB.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RedisDecrease(key string, value int64) error {
|
func RedisHDelObj(key string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
return RDB.HDel(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
|
||||||
|
// 使用反射遍历结构体字段
|
||||||
|
v := reflect.ValueOf(obj).Elem()
|
||||||
|
t := v.Type()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
value := v.Field(i)
|
||||||
|
|
||||||
|
// Skip DeletedAt field
|
||||||
|
if field.Type.String() == "gorm.DeletedAt" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理指针类型
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
if value.IsNil() {
|
||||||
|
data[field.Name] = ""
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理布尔类型
|
||||||
|
if value.Kind() == reflect.Bool {
|
||||||
|
data[field.Name] = strconv.FormatBool(value.Bool())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 其他类型直接转换为字符串
|
||||||
|
data[field.Name] = fmt.Sprintf("%v", value.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
txn := RDB.TxPipeline()
|
||||||
|
txn.HSet(ctx, key, data)
|
||||||
|
txn.Expire(ctx, key, expiration)
|
||||||
|
|
||||||
|
_, err := txn.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute transaction: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RedisHGetObj(key string, obj interface{}) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result, err := RDB.HGetAll(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load hash from Redis: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return fmt.Errorf("key %s not found in Redis", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle both pointer and non-pointer values
|
||||||
|
val := reflect.ValueOf(obj)
|
||||||
|
if val.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := val.Elem()
|
||||||
|
if v.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
t := v.Type()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
fieldName := field.Name
|
||||||
|
if value, ok := result[fieldName]; ok {
|
||||||
|
fieldValue := v.Field(i)
|
||||||
|
|
||||||
|
// Handle pointer types
|
||||||
|
if fieldValue.Kind() == reflect.Ptr {
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||||
|
}
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enhanced type handling for Token struct
|
||||||
|
switch fieldValue.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
fieldValue.SetString(value)
|
||||||
|
case reflect.Int, reflect.Int64:
|
||||||
|
intValue, err := strconv.ParseInt(value, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
|
||||||
|
}
|
||||||
|
fieldValue.SetInt(intValue)
|
||||||
|
case reflect.Bool:
|
||||||
|
boolValue, err := strconv.ParseBool(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
|
||||||
|
}
|
||||||
|
fieldValue.SetBool(boolValue)
|
||||||
|
case reflect.Struct:
|
||||||
|
// Special handling for gorm.DeletedAt
|
||||||
|
if fieldValue.Type().String() == "gorm.DeletedAt" {
|
||||||
|
if value != "" {
|
||||||
|
timeValue, err := time.Parse(time.RFC3339, value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
|
||||||
|
}
|
||||||
|
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisIncr Add this function to handle atomic increments
|
||||||
|
func RedisIncr(key string, delta int64) error {
|
||||||
// 检查键的剩余生存时间
|
// 检查键的剩余生存时间
|
||||||
ttlCmd := RDB.TTL(context.Background(), key)
|
ttlCmd := RDB.TTL(context.Background(), key)
|
||||||
ttl, err := ttlCmd.Result()
|
ttl, err := ttlCmd.Result()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
// 失败则尝试直接减少
|
return fmt.Errorf("failed to get TTL: %w", err)
|
||||||
return RDB.DecrBy(context.Background(), key, value).Err()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果剩余生存时间大于0,则进行减少操作
|
// 只有在 key 存在且有 TTL 时才需要特殊处理
|
||||||
if ttl > 0 {
|
if ttl > 0 {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
// 开始一个Redis事务
|
// 开始一个Redis事务
|
||||||
txn := RDB.TxPipeline()
|
txn := RDB.TxPipeline()
|
||||||
|
|
||||||
// 减少余额
|
// 减少余额
|
||||||
decrCmd := txn.DecrBy(ctx, key, value)
|
decrCmd := txn.IncrBy(ctx, key, delta)
|
||||||
if err := decrCmd.Err(); err != nil {
|
if err := decrCmd.Err(); err != nil {
|
||||||
return err // 如果减少失败,则直接返回错误
|
return err // 如果减少失败,则直接返回错误
|
||||||
}
|
}
|
||||||
@@ -101,26 +233,54 @@ func RedisDecrease(key string, value int64) error {
|
|||||||
// 执行事务
|
// 执行事务
|
||||||
_, err = txn.Exec(ctx)
|
_, err = txn.Exec(ctx)
|
||||||
return err
|
return err
|
||||||
} else {
|
|
||||||
_ = RedisDel(key)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedisIncr Add this function to handle atomic increments
|
func RedisHIncrBy(key, field string, delta int64) error {
|
||||||
func RedisIncr(key string, delta int) error {
|
ttlCmd := RDB.TTL(context.Background(), key)
|
||||||
ctx := context.Background()
|
ttl, err := ttlCmd.Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
return fmt.Errorf("failed to get TTL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 检查键是否存在
|
if ttl > 0 {
|
||||||
exists, err := RDB.Exists(ctx, key).Result()
|
ctx := context.Background()
|
||||||
if err != nil {
|
txn := RDB.TxPipeline()
|
||||||
|
|
||||||
|
incrCmd := txn.HIncrBy(ctx, key, field, delta)
|
||||||
|
if err := incrCmd.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
txn.Expire(ctx, key, ttl)
|
||||||
|
|
||||||
|
_, err = txn.Exec(ctx)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if exists == 0 {
|
return nil
|
||||||
return fmt.Errorf("key does not exist") // 键不存在,返回错误
|
}
|
||||||
|
|
||||||
|
func RedisHSetField(key, field string, value interface{}) error {
|
||||||
|
ttlCmd := RDB.TTL(context.Background(), key)
|
||||||
|
ttl, err := ttlCmd.Result()
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
return fmt.Errorf("failed to get TTL: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 键存在,执行INCRBY操作
|
if ttl > 0 {
|
||||||
result := RDB.IncrBy(ctx, key, int64(delta))
|
ctx := context.Background()
|
||||||
return result.Err()
|
txn := RDB.TxPipeline()
|
||||||
|
|
||||||
|
hsetCmd := txn.HSet(ctx, key, field, value)
|
||||||
|
if err := hsetCmd.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
txn.Expire(ctx, key, ttl)
|
||||||
|
|
||||||
|
_, err = txn.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,15 @@ var (
|
|||||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Cache keys
|
||||||
const (
|
const (
|
||||||
// Cache keys
|
|
||||||
UserGroupKeyFmt = "user_group:%d"
|
UserGroupKeyFmt = "user_group:%d"
|
||||||
UserQuotaKeyFmt = "user_quota:%d"
|
UserQuotaKeyFmt = "user_quota:%d"
|
||||||
UserEnabledKeyFmt = "user_enabled:%d"
|
UserEnabledKeyFmt = "user_enabled:%d"
|
||||||
UserUsernameKeyFmt = "user_name:%d"
|
UserUsernameKeyFmt = "user_name:%d"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenFiledRemainQuota = "RemainQuota"
|
||||||
|
TokenFieldGroup = "Group"
|
||||||
|
)
|
||||||
|
|||||||
3
main.go
3
main.go
@@ -80,9 +80,6 @@ func main() {
|
|||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
}
|
}
|
||||||
if common.RedisEnabled {
|
|
||||||
go model.SyncTokenCache(common.SyncFrequency)
|
|
||||||
}
|
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
go model.SyncOptions(common.SyncFrequency)
|
go model.SyncOptions(common.SyncFrequency)
|
||||||
go model.SyncChannelCache(common.SyncFrequency)
|
go model.SyncChannelCache(common.SyncFrequency)
|
||||||
|
|||||||
@@ -1,99 +1,16 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 仅用于定时同步缓存
|
|
||||||
var token2UserId = make(map[string]int)
|
|
||||||
var token2UserIdLock sync.RWMutex
|
|
||||||
|
|
||||||
func cacheSetToken(token *Token) error {
|
|
||||||
jsonBytes, err := json.Marshal(token)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
token2UserIdLock.Lock()
|
|
||||||
defer token2UserIdLock.Unlock()
|
|
||||||
token2UserId[token.Key] = token.UserId
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return GetTokenByKey(key)
|
|
||||||
}
|
|
||||||
var token *Token
|
|
||||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
|
||||||
if err != nil {
|
|
||||||
// 如果缓存中不存在,则从数据库中获取
|
|
||||||
token, err = GetTokenByKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = cacheSetToken(token)
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
// 如果缓存中存在,则续期时间
|
|
||||||
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second)
|
|
||||||
err = json.Unmarshal([]byte(tokenObjectString), &token)
|
|
||||||
return token, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func SyncTokenCache(frequency int) {
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
|
||||||
common.SysLog("syncing tokens from database")
|
|
||||||
token2UserIdLock.Lock()
|
|
||||||
// 从token2UserId中获取所有的key
|
|
||||||
var copyToken2UserId = make(map[string]int)
|
|
||||||
for s, i := range token2UserId {
|
|
||||||
copyToken2UserId[s] = i
|
|
||||||
}
|
|
||||||
token2UserId = make(map[string]int)
|
|
||||||
token2UserIdLock.Unlock()
|
|
||||||
|
|
||||||
for key := range copyToken2UserId {
|
|
||||||
token, err := GetTokenByKey(key)
|
|
||||||
if err != nil {
|
|
||||||
// 如果数据库中不存在,则删除缓存
|
|
||||||
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
|
|
||||||
//delete redis
|
|
||||||
err := common.RedisDel(fmt.Sprintf("token:%s", key))
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果数据库中存在,先检查redis
|
|
||||||
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
|
|
||||||
if err != nil {
|
|
||||||
// 如果redis中不存在,则跳过
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = cacheSetToken(token)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//func CacheGetUserGroup(id int) (group string, err error) {
|
//func CacheGetUserGroup(id int) (group string, err error) {
|
||||||
// if !common.RedisEnabled {
|
// if !common.RedisEnabled {
|
||||||
// return GetUserGroup(id)
|
// return GetUserGroup(id)
|
||||||
|
|||||||
10
model/log.go
10
model/log.go
@@ -12,16 +12,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var groupCol string
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if common.UsingPostgreSQL {
|
|
||||||
groupCol = `"group"`
|
|
||||||
} else {
|
|
||||||
groupCol = "`group`"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Log struct {
|
type Log struct {
|
||||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
|
|||||||
@@ -13,6 +13,20 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var groupCol string
|
||||||
|
var keyCol string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
keyCol = `"key"`
|
||||||
|
|
||||||
|
} else {
|
||||||
|
groupCol = "`group`"
|
||||||
|
keyCol = "`key`"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
|
|
||||||
var LOG_DB *gorm.DB
|
var LOG_DB *gorm.DB
|
||||||
|
|||||||
129
model/token.go
129
model/token.go
@@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -30,6 +31,10 @@ type Token struct {
|
|||||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (token *Token) Clean() {
|
||||||
|
token.Key = ""
|
||||||
|
}
|
||||||
|
|
||||||
func (token *Token) GetIpLimitsMap() map[string]any {
|
func (token *Token) GetIpLimitsMap() map[string]any {
|
||||||
// delete empty spaces
|
// delete empty spaces
|
||||||
//split with \n
|
//split with \n
|
||||||
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
if key == "" {
|
if key == "" {
|
||||||
return nil, errors.New("未提供令牌")
|
return nil, errors.New("未提供令牌")
|
||||||
}
|
}
|
||||||
token, err = CacheGetTokenByKey(key)
|
token, err = GetTokenByKey(key, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if token.Status == common.TokenStatusExhausted {
|
if token.Status == common.TokenStatusExhausted {
|
||||||
keyPrefix := key[:3]
|
keyPrefix := key[:3]
|
||||||
@@ -129,21 +134,37 @@ func GetTokenById(id int) (*Token, error) {
|
|||||||
var err error = nil
|
var err error = nil
|
||||||
err = DB.First(&token, "id = ?", id).Error
|
err = DB.First(&token, "id = ?", id).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if common.RedisEnabled {
|
gopool.Go(func() {
|
||||||
go cacheSetToken(&token)
|
if err := cacheSetToken(token); err != nil {
|
||||||
}
|
common.SysError("failed to update user status cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return &token, err
|
return &token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTokenByKey(key string) (*Token, error) {
|
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||||
keyCol := "`key`"
|
defer func() {
|
||||||
if common.UsingPostgreSQL {
|
// Update Redis cache asynchronously on successful DB read
|
||||||
keyCol = `"key"`
|
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||||
|
gopool.Go(func() {
|
||||||
|
if err := cacheSetToken(*token); err != nil {
|
||||||
|
common.SysError("failed to update user status cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if !fromDB && common.RedisEnabled {
|
||||||
|
// Try Redis first
|
||||||
|
token, err := cacheGetTokenByKey(key)
|
||||||
|
if err == nil {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
var token Token
|
fromDB = true
|
||||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||||
return &token, err
|
return token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) Insert() error {
|
func (token *Token) Insert() error {
|
||||||
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||||
func (token *Token) Update() error {
|
func (token *Token) Update() (err error) {
|
||||||
var err error
|
defer func() {
|
||||||
|
if common.RedisEnabled && err == nil {
|
||||||
|
gopool.Go(func() {
|
||||||
|
err := cacheSetToken(*token)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update token cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||||||
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) SelectUpdate() error {
|
func (token *Token) SelectUpdate() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if common.RedisEnabled && err == nil {
|
||||||
|
gopool.Go(func() {
|
||||||
|
err := cacheSetToken(*token)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update token cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
// This can update zero values
|
// This can update zero values
|
||||||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (token *Token) Delete() error {
|
func (token *Token) Delete() (err error) {
|
||||||
var err error
|
defer func() {
|
||||||
|
if common.RedisEnabled && err == nil {
|
||||||
|
gopool.Go(func() {
|
||||||
|
err := cacheDeleteToken(token.Key)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to delete token cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
err = DB.Delete(token).Error
|
err = DB.Delete(token).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -214,10 +263,16 @@ func DeleteTokenById(id int, userId int) (err error) {
|
|||||||
return token.Delete()
|
return token.Delete()
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseTokenQuota(id int, quota int) (err error) {
|
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
gopool.Go(func() {
|
||||||
|
err := cacheIncrTokenQuota(key, int64(quota))
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to increase token quota: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
if common.BatchUpdateEnabled {
|
if common.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||||
return nil
|
return nil
|
||||||
@@ -236,10 +291,16 @@ func increaseTokenQuota(id int, quota int) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseTokenQuota(id int, quota int) (err error) {
|
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
gopool.Go(func() {
|
||||||
|
err := cacheDecrTokenQuota(key, int64(quota))
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to decrease token quota: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
if common.BatchUpdateEnabled {
|
if common.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||||
return nil
|
return nil
|
||||||
@@ -262,20 +323,22 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
if !relayInfo.IsPlayground {
|
if relayInfo.IsPlayground {
|
||||||
token, err := GetTokenById(relayInfo.TokenId)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
|
||||||
return errors.New("令牌额度不足")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if !relayInfo.IsPlayground {
|
//if relayInfo.TokenUnlimited {
|
||||||
err := DecreaseTokenQuota(relayInfo.TokenId, quota)
|
// return nil
|
||||||
if err != nil {
|
//}
|
||||||
return err
|
token, err := GetTokenById(relayInfo.TokenId)
|
||||||
}
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||||
|
return errors.New("令牌额度不足")
|
||||||
|
}
|
||||||
|
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -293,9 +356,9 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int
|
|||||||
|
|
||||||
if !relayInfo.IsPlayground {
|
if !relayInfo.IsPlayground {
|
||||||
if quota > 0 {
|
if quota > 0 {
|
||||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||||
} else {
|
} else {
|
||||||
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
|
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
64
model/token_cache.go
Normal file
64
model/token_cache.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cacheSetToken(token Token) error {
|
||||||
|
key := common.GenerateHMAC(token.Key)
|
||||||
|
token.Clean()
|
||||||
|
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheDeleteToken(key string) error {
|
||||||
|
key = common.GenerateHMAC(key)
|
||||||
|
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheIncrTokenQuota(key string, increment int64) error {
|
||||||
|
key = common.GenerateHMAC(key)
|
||||||
|
err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheDecrTokenQuota(key string, decrement int64) error {
|
||||||
|
return cacheIncrTokenQuota(key, -decrement)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheSetTokenField(key string, field string, value string) error {
|
||||||
|
key = common.GenerateHMAC(key)
|
||||||
|
err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取
|
||||||
|
func cacheGetTokenByKey(key string) (*Token, error) {
|
||||||
|
hmacKey := common.GenerateHMAC(key)
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var token Token
|
||||||
|
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token.Key = key
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
@@ -252,7 +252,7 @@ func (user *User) Update(updatePassword bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新缓存
|
// 更新缓存
|
||||||
return updateUserCache(user)
|
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Edit(updatePassword bool) error {
|
func (user *User) Edit(updatePassword bool) error {
|
||||||
@@ -281,7 +281,7 @@ func (user *User) Edit(updatePassword bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 更新缓存
|
// 更新缓存
|
||||||
return updateUserCache(user)
|
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Delete() error {
|
func (user *User) Delete() error {
|
||||||
@@ -411,7 +411,7 @@ func IsAdmin(userId int) bool {
|
|||||||
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Update Redis cache asynchronously on successful DB read
|
// Update Redis cache asynchronously on successful DB read
|
||||||
if common.RedisEnabled {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserStatusCache(id, status); err != nil {
|
if err := updateUserStatusCache(id, status); err != nil {
|
||||||
common.SysError("failed to update user status cache: " + err.Error())
|
common.SysError("failed to update user status cache: " + err.Error())
|
||||||
@@ -427,7 +427,7 @@ func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
|||||||
}
|
}
|
||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
|
fromDB = true
|
||||||
var user User
|
var user User
|
||||||
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -453,7 +453,7 @@ func ValidateAccessToken(token string) (user *User) {
|
|||||||
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Update Redis cache asynchronously on successful DB read
|
// Update Redis cache asynchronously on successful DB read
|
||||||
if common.RedisEnabled && err == nil {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserQuotaCache(id, quota); err != nil {
|
if err := updateUserQuotaCache(id, quota); err != nil {
|
||||||
common.SysError("failed to update user quota cache: " + err.Error())
|
common.SysError("failed to update user quota cache: " + err.Error())
|
||||||
@@ -469,7 +469,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
|||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
//common.SysError("failed to get user quota from cache: " + err.Error())
|
//common.SysError("failed to get user quota from cache: " + err.Error())
|
||||||
}
|
}
|
||||||
|
fromDB = true
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -492,7 +492,7 @@ func GetUserEmail(id int) (email string, err error) {
|
|||||||
func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Update Redis cache asynchronously on successful DB read
|
// Update Redis cache asynchronously on successful DB read
|
||||||
if common.RedisEnabled && err == nil {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserGroupCache(id, group); err != nil {
|
if err := updateUserGroupCache(id, group); err != nil {
|
||||||
common.SysError("failed to update user group cache: " + err.Error())
|
common.SysError("failed to update user group cache: " + err.Error())
|
||||||
@@ -507,7 +507,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
|||||||
}
|
}
|
||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
|
fromDB = true
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -521,7 +521,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheIncrUserQuota(id, quota)
|
err := cacheIncrUserQuota(id, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to increase user quota: " + err.Error())
|
common.SysError("failed to increase user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -546,7 +546,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDecrUserQuota(id, quota)
|
err := cacheDecrUserQuota(id, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to decrease user quota: " + err.Error())
|
common.SysError("failed to decrease user quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -631,7 +631,7 @@ func updateUserRequestCount(id int, count int) {
|
|||||||
func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Update Redis cache asynchronously on successful DB read
|
// Update Redis cache asynchronously on successful DB read
|
||||||
if common.RedisEnabled && err == nil {
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := updateUserNameCache(id, username); err != nil {
|
if err := updateUserNameCache(id, username); err != nil {
|
||||||
common.SysError("failed to update user name cache: " + err.Error())
|
common.SysError("failed to update user name cache: " + err.Error())
|
||||||
@@ -646,7 +646,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
|||||||
}
|
}
|
||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
}
|
}
|
||||||
|
fromDB = true
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -93,24 +93,24 @@ func updateUserNameCache(userId int, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateUserCache updates all user cache fields
|
// updateUserCache updates all user cache fields
|
||||||
func updateUserCache(user *User) error {
|
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := updateUserGroupCache(user.Id, user.Group); err != nil {
|
if err := updateUserGroupCache(userId, userGroup); err != nil {
|
||||||
return fmt.Errorf("update group cache: %w", err)
|
return fmt.Errorf("update group cache: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := updateUserQuotaCache(user.Id, user.Quota); err != nil {
|
if err := updateUserQuotaCache(userId, quota); err != nil {
|
||||||
return fmt.Errorf("update quota cache: %w", err)
|
return fmt.Errorf("update quota cache: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := updateUserStatusCache(user.Id, user.Status == common.UserStatusEnabled); err != nil {
|
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
|
||||||
return fmt.Errorf("update status cache: %w", err)
|
return fmt.Errorf("update status cache: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := updateUserNameCache(user.Id, user.Username); err != nil {
|
if err := updateUserNameCache(userId, username); err != nil {
|
||||||
return fmt.Errorf("update username cache: %w", err)
|
return fmt.Errorf("update username cache: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +193,7 @@ func getUserCache(userId int) (*userCache, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add atomic quota operations
|
// Add atomic quota operations
|
||||||
func cacheIncrUserQuota(userId int, delta int) error {
|
func cacheIncrUserQuota(userId int, delta int64) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -201,6 +201,6 @@ func cacheIncrUserQuota(userId int, delta int) error {
|
|||||||
return common.RedisIncr(key, delta)
|
return common.RedisIncr(key, delta)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheDecrUserQuota(userId int, delta int) error {
|
func cacheDecrUserQuota(userId int, delta int64) error {
|
||||||
return cacheIncrUserQuota(userId, -delta)
|
return cacheIncrUserQuota(userId, -delta)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
|
|||||||
}
|
}
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldUpdateRedis(fromDB bool, err error) bool {
|
||||||
|
return common.RedisEnabled && fromDB && err == nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,19 +81,9 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||||
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
|
if openaiErr != nil {
|
||||||
}
|
return openaiErr
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
}
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
|
|||||||
@@ -291,14 +291,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return preConsumedQuota, userQuota, nil
|
return preConsumedQuota, userQuota, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"))
|
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user