first commit: one-api base code + SAAS plan document
This commit is contained in:
112
model/ability.go
Normal file
112
model/ability.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/utils"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
|
||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
var channelQuery *gorm.DB
|
||||
if ignoreFirstPriority {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
} else {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
}
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
} else {
|
||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channel := Channel{}
|
||||
channel.Id = ability.ChannelId
|
||||
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
models_ = utils.DeDuplication(models_)
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
}
|
||||
return DB.Create(&abilities).Error
|
||||
}
|
||||
|
||||
func (channel *Channel) DeleteAbilities() error {
|
||||
return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
|
||||
}
|
||||
|
||||
// UpdateAbilities updates abilities of this channel.
|
||||
// Make sure the channel is completed before calling this function.
|
||||
func (channel *Channel) UpdateAbilities() error {
|
||||
// A quick and dirty way to update abilities
|
||||
// First delete all abilities of this channel
|
||||
err := channel.DeleteAbilities()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Then add new abilities
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAbilityStatus(channelId int, status bool) error {
|
||||
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
||||
}
|
||||
|
||||
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
var models []string
|
||||
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(models)
|
||||
return models, err
|
||||
}
|
||||
255
model/cache.go
Normal file
255
model/cache.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = config.SyncFrequency
|
||||
UserId2GroupCacheSeconds = config.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = config.SyncFrequency
|
||||
UserId2StatusCacheSeconds = config.SyncFrequency
|
||||
GroupModelsCacheSeconds = config.SyncFrequency
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
var token Token
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonBytes, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.SysError("Redis set token error: " + err.Error())
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
err = json.Unmarshal([]byte(tokenObjectString), &token)
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func CacheGetUserGroup(id int) (group string, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserGroup(id)
|
||||
}
|
||||
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
||||
if err != nil {
|
||||
group, err = GetUserGroup(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return group, err
|
||||
}
|
||||
|
||||
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||
quota, err = GetUserQuota(id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "Redis set user quota error: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserQuota(id)
|
||||
}
|
||||
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||
if err != nil {
|
||||
return fetchAndUpdateUserQuota(ctx, id)
|
||||
}
|
||||
quota, err = strconv.ParseInt(quotaString, 10, 64)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
|
||||
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
|
||||
return fetchAndUpdateUserQuota(ctx, id)
|
||||
}
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
func CacheUpdateUserQuota(ctx context.Context, id int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
quota, err := CacheGetUserQuota(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheDecreaseUserQuota(id int, quota int64) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsUserEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
userEnabled, err := IsUserEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetGroupModels(ctx, group)
|
||||
}
|
||||
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
|
||||
if err == nil {
|
||||
return strings.Split(modelsStr, ","), nil
|
||||
}
|
||||
models, err := GetGroupModels(ctx, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.SysError("Redis set group models error: " + err.Error())
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
var abilities []*Ability
|
||||
DB.Find(&abilities)
|
||||
groups := make(map[string]bool)
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelSyncLock.Unlock()
|
||||
logger.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
logger.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
|
||||
if !config.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
if ignoreFirstPriority {
|
||||
if endIdx < len(channels) { // which means there are more than one priority
|
||||
idx = random.RandRange(endIdx, len(channels))
|
||||
}
|
||||
}
|
||||
return channels[idx], nil
|
||||
}
|
||||
224
model/channel.go
Normal file
224
model/channel.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelStatusUnknown = 0
|
||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
Id int `json:"id"`
|
||||
Type int `json:"type" gorm:"default:0"`
|
||||
Key string `json:"key" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Weight *uint `json:"weight" gorm:"default:0"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||
ResponseTime int `json:"response_time"` // in milliseconds
|
||||
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||||
Other *string `json:"other"` // DEPRECATED: please save config to field Config
|
||||
Balance float64 `json:"balance"` // in USD
|
||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||
Models string `json:"models"`
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
Config string `json:"config"`
|
||||
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
|
||||
}
|
||||
|
||||
type ChannelConfig struct {
|
||||
Region string `json:"region,omitempty"`
|
||||
SK string `json:"sk,omitempty"`
|
||||
AK string `json:"ak,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
APIVersion string `json:"api_version,omitempty"`
|
||||
LibraryID string `json:"library_id,omitempty"`
|
||||
Plugin string `json:"plugin,omitempty"`
|
||||
VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
|
||||
VertexAIADC string `json:"vertex_ai_adc,omitempty"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
var err error
|
||||
switch scope {
|
||||
case "all":
|
||||
err = DB.Order("id desc").Find(&channels).Error
|
||||
case "disabled":
|
||||
err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
|
||||
default:
|
||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
}
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
var err error
|
||||
err = DB.Create(&channels).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, channel_ := range channels {
|
||||
err = channel_.AddAbilities()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetPriority() int64 {
|
||||
if channel.Priority == nil {
|
||||
return 0
|
||||
}
|
||||
return *channel.Priority
|
||||
}
|
||||
|
||||
func (channel *Channel) GetBaseURL() string {
|
||||
if channel.BaseURL == nil {
|
||||
return ""
|
||||
}
|
||||
return *channel.BaseURL
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModelMapping() map[string]string {
|
||||
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
|
||||
return nil
|
||||
}
|
||||
modelMapping := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
|
||||
return nil
|
||||
}
|
||||
return modelMapping
|
||||
}
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(channel).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = channel.AddAbilities()
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||||
err = channel.UpdateAbilities()
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||||
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
||||
TestTime: helper.GetTimestamp(),
|
||||
ResponseTime: int(responseTime),
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update response time: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) UpdateBalance(balance float64) {
|
||||
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||||
BalanceUpdatedTime: helper.GetTimestamp(),
|
||||
Balance: balance,
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update balance: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) Delete() error {
|
||||
var err error
|
||||
err = DB.Delete(channel).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = channel.DeleteAbilities()
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) LoadConfig() (ChannelConfig, error) {
|
||||
var cfg ChannelConfig
|
||||
if channel.Config == "" {
|
||||
return cfg, nil
|
||||
}
|
||||
err := json.Unmarshal([]byte(channel.Config), &cfg)
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func UpdateChannelStatusById(id int, status int) {
|
||||
err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
logger.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update channel status: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateChannelUsedQuota(id int, quota int64) {
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||
return
|
||||
}
|
||||
updateChannelUsedQuota(id, quota)
|
||||
}
|
||||
|
||||
func updateChannelUsedQuota(id int, quota int64) {
|
||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteChannelByStatus(status int64) (int64, error) {
|
||||
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func DeleteDisabledChannel() (int64, error) {
|
||||
result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
251
model/log.go
Normal file
251
model/log.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
RequestId string `json:"request_id" gorm:"default:''"`
|
||||
ElapsedTime int64 `json:"elapsed_time" gorm:"default:0"` // unit is ms
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
SystemPromptReset bool `json:"system_prompt_reset" gorm:"default:false"`
|
||||
}
|
||||
|
||||
const (
|
||||
LogTypeUnknown = iota
|
||||
LogTypeTopup
|
||||
LogTypeConsume
|
||||
LogTypeManage
|
||||
LogTypeSystem
|
||||
LogTypeTest
|
||||
)
|
||||
|
||||
func recordLogHelper(ctx context.Context, log *Log) {
|
||||
requestId := helper.GetRequestID(ctx)
|
||||
log.RequestId = requestId
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to record log: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.Infof(ctx, "record log: %+v", log)
|
||||
}
|
||||
|
||||
func RecordLog(ctx context.Context, userId int, logType int, content string) {
|
||||
if logType == LogTypeConsume && !config.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: GetUsernameById(userId),
|
||||
CreatedAt: helper.GetTimestamp(),
|
||||
Type: logType,
|
||||
Content: content,
|
||||
}
|
||||
recordLogHelper(ctx, log)
|
||||
}
|
||||
|
||||
func RecordTopupLog(ctx context.Context, userId int, content string, quota int) {
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: GetUsernameById(userId),
|
||||
CreatedAt: helper.GetTimestamp(),
|
||||
Type: LogTypeTopup,
|
||||
Content: content,
|
||||
Quota: quota,
|
||||
}
|
||||
recordLogHelper(ctx, log)
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, log *Log) {
|
||||
if !config.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
log.Username = GetUsernameById(log.UserId)
|
||||
log.CreatedAt = helper.GetTimestamp()
|
||||
log.Type = LogTypeConsume
|
||||
recordLogHelper(ctx, log)
|
||||
}
|
||||
|
||||
func RecordTestLog(ctx context.Context, log *Log) {
|
||||
log.CreatedAt = helper.GetTimestamp()
|
||||
log.Type = LogTypeTest
|
||||
recordLogHelper(ctx, log)
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB
|
||||
} else {
|
||||
tx = LOG_DB.Where("type = ?", logType)
|
||||
}
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB.Where("user_id = ?", userId)
|
||||
} else {
|
||||
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
|
||||
}
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
|
||||
ifnull := "ifnull"
|
||||
if common.UsingPostgreSQL {
|
||||
ifnull = "COALESCE"
|
||||
}
|
||||
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||
return quota
|
||||
}
|
||||
|
||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||
ifnull := "ifnull"
|
||||
if common.UsingPostgreSQL {
|
||||
ifnull = "COALESCE"
|
||||
}
|
||||
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("token_name = ?", tokenName)
|
||||
}
|
||||
if startTimestamp != 0 {
|
||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||
}
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
||||
return token
|
||||
}
|
||||
|
||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
type LogStatistic struct {
|
||||
Day string `gorm:"column:day"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
RequestCount int `gorm:"column:request_count"`
|
||||
Quota int `gorm:"column:quota"`
|
||||
PromptTokens int `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens int `gorm:"column:completion_tokens"`
|
||||
}
|
||||
|
||||
func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) {
|
||||
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
|
||||
}
|
||||
|
||||
if common.UsingSQLite {
|
||||
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
||||
}
|
||||
|
||||
err = LOG_DB.Raw(`
|
||||
SELECT `+groupSelect+`,
|
||||
model_name, count(1) as request_count,
|
||||
sum(quota) as quota,
|
||||
sum(prompt_tokens) as prompt_tokens,
|
||||
sum(completion_tokens) as completion_tokens
|
||||
FROM logs
|
||||
WHERE type=2
|
||||
AND user_id= ?
|
||||
AND created_at BETWEEN ? AND ?
|
||||
GROUP BY day, model_name
|
||||
ORDER BY day, model_name
|
||||
`, userId, start, end).Scan(&LogStatistics).Error
|
||||
|
||||
return LogStatistics, err
|
||||
}
|
||||
237
model/main.go
Normal file
237
model/main.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
func CreateRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != util.UserStatusEnabled {
|
||||
if err := DB.First(&user).Error; err != nil {
|
||||
logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456")
|
||||
hashedPassword, err := common.Password2Hash("123456")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessToken := random.GetUUID()
|
||||
if config.InitialRootAccessToken != "" {
|
||||
accessToken = config.InitialRootAccessToken
|
||||
}
|
||||
rootUser := User{
|
||||
Username: "root",
|
||||
Password: hashedPassword,
|
||||
Role: RoleRootUser,
|
||||
Status: UserStatusEnabled,
|
||||
DisplayName: "Root User",
|
||||
AccessToken: accessToken,
|
||||
Quota: 500000000000000,
|
||||
}
|
||||
DB.Create(&rootUser)
|
||||
if config.InitialRootToken != "" {
|
||||
logger.SysLog("creating initial root token as requested")
|
||||
token := Token{
|
||||
Id: 1,
|
||||
UserId: rootUser.Id,
|
||||
Key: config.InitialRootToken,
|
||||
Status: TokenStatusEnabled,
|
||||
Name: "Initial Root Token",
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
AccessedTime: helper.GetTimestamp(),
|
||||
ExpiredTime: -1,
|
||||
RemainQuota: 500000000000000,
|
||||
UnlimitedQuota: true,
|
||||
}
|
||||
DB.Create(&token)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func chooseDB(envName string) (*gorm.DB, error) {
|
||||
dsn := os.Getenv(envName)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(dsn, "postgres://"):
|
||||
// Use PostgreSQL
|
||||
return openPostgreSQL(dsn)
|
||||
case dsn != "":
|
||||
// Use MySQL
|
||||
return openMySQL(dsn)
|
||||
default:
|
||||
// Use SQLite
|
||||
return openSQLite()
|
||||
}
|
||||
}
|
||||
|
||||
func openPostgreSQL(dsn string) (*gorm.DB, error) {
|
||||
logger.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
return gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dsn,
|
||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||
}), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
|
||||
func openMySQL(dsn string) (*gorm.DB, error) {
|
||||
logger.SysLog("using MySQL as database")
|
||||
common.UsingMySQL = true
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
|
||||
func openSQLite() (*gorm.DB, error) {
|
||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
|
||||
return gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
|
||||
func InitDB() {
|
||||
var err error
|
||||
DB, err = chooseDB("SQL_DSN")
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to initialize database: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sqlDB := setDBConns(DB)
|
||||
|
||||
if !config.IsMasterNode {
|
||||
return
|
||||
}
|
||||
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
|
||||
logger.SysLog("database migration started")
|
||||
if err = migrateDB(); err != nil {
|
||||
logger.FatalLog("failed to migrate database: " + err.Error())
|
||||
return
|
||||
}
|
||||
logger.SysLog("database migrated")
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
var err error
|
||||
if err = DB.AutoMigrate(&Channel{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Token{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&User{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Option{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Redemption{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Ability{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Log{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = DB.AutoMigrate(&Channel{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InitLogDB() {
|
||||
if os.Getenv("LOG_SQL_DSN") == "" {
|
||||
LOG_DB = DB
|
||||
return
|
||||
}
|
||||
|
||||
logger.SysLog("using secondary database for table logs")
|
||||
var err error
|
||||
LOG_DB, err = chooseDB("LOG_SQL_DSN")
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to initialize secondary database: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
setDBConns(LOG_DB)
|
||||
|
||||
if !config.IsMasterNode {
|
||||
return
|
||||
}
|
||||
|
||||
logger.SysLog("secondary database migration started")
|
||||
err = migrateLOGDB()
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to migrate secondary database: " + err.Error())
|
||||
return
|
||||
}
|
||||
logger.SysLog("secondary database migrated")
|
||||
}
|
||||
|
||||
func migrateLOGDB() error {
|
||||
var err error
|
||||
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setDBConns(db *gorm.DB) *sql.DB {
|
||||
if config.DebugSQLEnabled {
|
||||
db = db.Debug()
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to connect database: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||
return sqlDB
|
||||
}
|
||||
|
||||
func closeDB(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sqlDB.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func CloseDB() error {
|
||||
if LOG_DB != DB {
|
||||
err := closeDB(LOG_DB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return closeDB(DB)
|
||||
}
|
||||
244
model/option.go
Normal file
244
model/option.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
Key string `json:"key" gorm:"primaryKey"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
func AllOption() ([]*Option, error) {
|
||||
var options []*Option
|
||||
var err error
|
||||
err = DB.Find(&options).Error
|
||||
return options, err
|
||||
}
|
||||
|
||||
func InitOptionMap() {
|
||||
config.OptionMapRWMutex.Lock()
|
||||
config.OptionMap = make(map[string]string)
|
||||
config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
|
||||
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
|
||||
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
|
||||
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
|
||||
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
|
||||
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
|
||||
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
|
||||
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
|
||||
config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
|
||||
config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
|
||||
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
|
||||
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
|
||||
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
|
||||
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
|
||||
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
|
||||
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
|
||||
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
|
||||
config.OptionMap["SMTPServer"] = ""
|
||||
config.OptionMap["SMTPFrom"] = ""
|
||||
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
|
||||
config.OptionMap["SMTPAccount"] = ""
|
||||
config.OptionMap["SMTPToken"] = ""
|
||||
config.OptionMap["Notice"] = ""
|
||||
config.OptionMap["About"] = ""
|
||||
config.OptionMap["HomePageContent"] = ""
|
||||
config.OptionMap["Footer"] = config.Footer
|
||||
config.OptionMap["SystemName"] = config.SystemName
|
||||
config.OptionMap["Logo"] = config.Logo
|
||||
config.OptionMap["ServerAddress"] = ""
|
||||
config.OptionMap["GitHubClientId"] = ""
|
||||
config.OptionMap["GitHubClientSecret"] = ""
|
||||
config.OptionMap["WeChatServerAddress"] = ""
|
||||
config.OptionMap["WeChatServerToken"] = ""
|
||||
config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
|
||||
config.OptionMap["MessagePusherAddress"] = ""
|
||||
config.OptionMap["MessagePusherToken"] = ""
|
||||
config.OptionMap["TurnstileSiteKey"] = ""
|
||||
config.OptionMap["TurnstileSecretKey"] = ""
|
||||
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
|
||||
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
|
||||
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
|
||||
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
|
||||
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
|
||||
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
|
||||
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
|
||||
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
|
||||
config.OptionMap["TopUpLink"] = config.TopUpLink
|
||||
config.OptionMap["ChatLink"] = config.ChatLink
|
||||
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
|
||||
config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
|
||||
config.OptionMap["Theme"] = config.Theme
|
||||
config.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
|
||||
func loadOptionsFromDatabase() {
|
||||
options, _ := AllOption()
|
||||
for _, option := range options {
|
||||
if option.Key == "ModelRatio" {
|
||||
option.Value = billingratio.AddNewMissingRatio(option.Value)
|
||||
}
|
||||
err := updateOptionMap(option.Key, option.Value)
|
||||
if err != nil {
|
||||
logger.SysError("failed to update option map: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SyncOptions(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
logger.SysLog("syncing options from database")
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateOption(key string, value string) error {
|
||||
// Save to database first
|
||||
option := Option{
|
||||
Key: key,
|
||||
}
|
||||
// https://gorm.io/docs/update.html#Save-All-Fields
|
||||
DB.FirstOrCreate(&option, Option{Key: key})
|
||||
option.Value = value
|
||||
// Save is a combination function.
|
||||
// If save value does not contain primary key, it will execute Create,
|
||||
// otherwise it will execute Update (with all fields).
|
||||
DB.Save(&option)
|
||||
// Update OptionMap
|
||||
return updateOptionMap(key, value)
|
||||
}
|
||||
|
||||
func updateOptionMap(key string, value string) (err error) {
|
||||
config.OptionMapRWMutex.Lock()
|
||||
defer config.OptionMapRWMutex.Unlock()
|
||||
config.OptionMap[key] = value
|
||||
if strings.HasSuffix(key, "Enabled") {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
config.PasswordRegisterEnabled = boolValue
|
||||
case "PasswordLoginEnabled":
|
||||
config.PasswordLoginEnabled = boolValue
|
||||
case "EmailVerificationEnabled":
|
||||
config.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
config.GitHubOAuthEnabled = boolValue
|
||||
case "OidcEnabled":
|
||||
config.OidcEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
config.WeChatAuthEnabled = boolValue
|
||||
case "TurnstileCheckEnabled":
|
||||
config.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
config.RegisterEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
config.EmailDomainRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
config.AutomaticDisableChannelEnabled = boolValue
|
||||
case "AutomaticEnableChannelEnabled":
|
||||
config.AutomaticEnableChannelEnabled = boolValue
|
||||
case "ApproximateTokenEnabled":
|
||||
config.ApproximateTokenEnabled = boolValue
|
||||
case "LogConsumeEnabled":
|
||||
config.LogConsumeEnabled = boolValue
|
||||
case "DisplayInCurrencyEnabled":
|
||||
config.DisplayInCurrencyEnabled = boolValue
|
||||
case "DisplayTokenStatEnabled":
|
||||
config.DisplayTokenStatEnabled = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
case "EmailDomainWhitelist":
|
||||
config.EmailDomainWhitelist = strings.Split(value, ",")
|
||||
case "SMTPServer":
|
||||
config.SMTPServer = value
|
||||
case "SMTPPort":
|
||||
intValue, _ := strconv.Atoi(value)
|
||||
config.SMTPPort = intValue
|
||||
case "SMTPAccount":
|
||||
config.SMTPAccount = value
|
||||
case "SMTPFrom":
|
||||
config.SMTPFrom = value
|
||||
case "SMTPToken":
|
||||
config.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
config.ServerAddress = value
|
||||
case "GitHubClientId":
|
||||
config.GitHubClientId = value
|
||||
case "GitHubClientSecret":
|
||||
config.GitHubClientSecret = value
|
||||
case "LarkClientId":
|
||||
config.LarkClientId = value
|
||||
case "LarkClientSecret":
|
||||
config.LarkClientSecret = value
|
||||
case "OidcClientId":
|
||||
config.OidcClientId = value
|
||||
case "OidcClientSecret":
|
||||
config.OidcClientSecret = value
|
||||
case "OidcWellKnown":
|
||||
config.OidcWellKnown = value
|
||||
case "OidcAuthorizationEndpoint":
|
||||
config.OidcAuthorizationEndpoint = value
|
||||
case "OidcTokenEndpoint":
|
||||
config.OidcTokenEndpoint = value
|
||||
case "OidcUserinfoEndpoint":
|
||||
config.OidcUserinfoEndpoint = value
|
||||
case "Footer":
|
||||
config.Footer = value
|
||||
case "SystemName":
|
||||
config.SystemName = value
|
||||
case "Logo":
|
||||
config.Logo = value
|
||||
case "WeChatServerAddress":
|
||||
config.WeChatServerAddress = value
|
||||
case "WeChatServerToken":
|
||||
config.WeChatServerToken = value
|
||||
case "WeChatAccountQRCodeImageURL":
|
||||
config.WeChatAccountQRCodeImageURL = value
|
||||
case "MessagePusherAddress":
|
||||
config.MessagePusherAddress = value
|
||||
case "MessagePusherToken":
|
||||
config.MessagePusherToken = value
|
||||
case "TurnstileSiteKey":
|
||||
config.TurnstileSiteKey = value
|
||||
case "TurnstileSecretKey":
|
||||
config.TurnstileSecretKey = value
|
||||
case "QuotaForNewUser":
|
||||
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
|
||||
case "QuotaForInviter":
|
||||
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
|
||||
case "QuotaForInvitee":
|
||||
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
|
||||
case "QuotaRemindThreshold":
|
||||
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
|
||||
case "PreConsumedQuota":
|
||||
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
|
||||
case "RetryTimes":
|
||||
config.RetryTimes, _ = strconv.Atoi(value)
|
||||
case "ModelRatio":
|
||||
err = billingratio.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = billingratio.UpdateGroupRatioByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = billingratio.UpdateCompletionRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
config.TopUpLink = value
|
||||
case "ChatLink":
|
||||
config.ChatLink = value
|
||||
case "ChannelDisableThreshold":
|
||||
config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
||||
case "QuotaPerUnit":
|
||||
config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||
case "Theme":
|
||||
config.Theme = value
|
||||
}
|
||||
return err
|
||||
}
|
||||
126
model/redemption.go
Normal file
126
model/redemption.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
const (
|
||||
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
RedemptionCodeStatusDisabled = 2 // also don't use 0
|
||||
RedemptionCodeStatusUsed = 3 // also don't use 0
|
||||
)
|
||||
|
||||
type Redemption struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Quota int64 `json:"quota" gorm:"bigint;default:100"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
}
|
||||
|
||||
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
|
||||
var redemptions []*Redemption
|
||||
var err error
|
||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
|
||||
return redemptions, err
|
||||
}
|
||||
|
||||
func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) {
|
||||
err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error
|
||||
return redemptions, err
|
||||
}
|
||||
|
||||
func GetRedemptionById(id int) (*Redemption, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
redemption := Redemption{Id: id}
|
||||
var err error = nil
|
||||
err = DB.First(&redemption, "id = ?", id).Error
|
||||
return &redemption, err
|
||||
}
|
||||
|
||||
func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) {
|
||||
if key == "" {
|
||||
return 0, errors.New("未提供兑换码")
|
||||
}
|
||||
if userId == 0 {
|
||||
return 0, errors.New("无效的 user id")
|
||||
}
|
||||
redemption := &Redemption{}
|
||||
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
if redemption.Status != RedemptionCodeStatusEnabled {
|
||||
return errors.New("该兑换码已被使用")
|
||||
}
|
||||
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
redemption.RedeemedTime = helper.GetTimestamp()
|
||||
redemption.Status = RedemptionCodeStatusUsed
|
||||
err = tx.Save(redemption).Error
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return 0, errors.New("兑换失败," + err.Error())
|
||||
}
|
||||
RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
|
||||
return redemption.Quota, nil
|
||||
}
|
||||
|
||||
func (redemption *Redemption) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(redemption).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (redemption *Redemption) SelectUpdate() error {
|
||||
// This can update zero values
|
||||
return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (redemption *Redemption) Update() error {
|
||||
var err error
|
||||
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (redemption *Redemption) Delete() error {
|
||||
var err error
|
||||
err = DB.Delete(redemption).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteRedemptionById(id int) (err error) {
|
||||
if id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
redemption := Redemption{Id: id}
|
||||
err = DB.Where(redemption).First(&redemption).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return redemption.Delete()
|
||||
}
|
||||
303
model/token.go
Normal file
303
model/token.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/message"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
TokenStatusDisabled = 2 // also don't use 0
|
||||
TokenStatusExpired = 3
|
||||
TokenStatusExhausted = 4
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
|
||||
Models *string `json:"models" gorm:"type:text"` // allowed models
|
||||
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
|
||||
var tokens []*Token
|
||||
var err error
|
||||
query := DB.Where("user_id = ?", userId)
|
||||
|
||||
switch order {
|
||||
case "remain_quota":
|
||||
query = query.Order("unlimited_quota desc, remain_quota desc")
|
||||
case "used_quota":
|
||||
query = query.Order("used_quota desc")
|
||||
default:
|
||||
query = query.Order("id desc")
|
||||
}
|
||||
|
||||
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func ValidateUserToken(key string) (token *Token, err error) {
|
||||
if key == "" {
|
||||
return nil, errors.New("未提供令牌")
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err != nil {
|
||||
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("无效的令牌")
|
||||
}
|
||||
return nil, errors.New("令牌验证失败")
|
||||
}
|
||||
if token.Status == TokenStatusExhausted {
|
||||
return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id)
|
||||
} else if token.Status == TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if token.Status != TokenStatusEnabled {
|
||||
return nil, errors.New("该令牌状态不可用")
|
||||
}
|
||||
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
|
||||
if !common.RedisEnabled {
|
||||
token.Status = TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||
if !common.RedisEnabled {
|
||||
// in this case, we can make sure the token is exhausted
|
||||
token.Status = TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func GetTokenByIds(id int, userId int) (*Token, error) {
|
||||
if id == 0 || userId == 0 {
|
||||
return nil, errors.New("id 或 userId 为空!")
|
||||
}
|
||||
token := Token{Id: id, UserId: userId}
|
||||
var err error = nil
|
||||
err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func GetTokenById(id int) (*Token, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
token := Token{Id: id}
|
||||
var err error = nil
|
||||
err = DB.First(&token, "id = ?", id).Error
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func (t *Token) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (t *Token) Update() error {
|
||||
var err error
|
||||
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Token) SelectUpdate() error {
|
||||
// This can update zero values
|
||||
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
|
||||
}
|
||||
|
||||
func (t *Token) Delete() error {
|
||||
var err error
|
||||
err = DB.Delete(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Token) GetModels() string {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
if t.Models == nil {
|
||||
return ""
|
||||
}
|
||||
return *t.Models
|
||||
}
|
||||
|
||||
func DeleteTokenById(id int, userId int) (err error) {
|
||||
// Why we need userId here? In case user want to delete other's token.
|
||||
if id == 0 || userId == 0 {
|
||||
return errors.New("id 或 userId 为空!")
|
||||
}
|
||||
token := Token{Id: id, UserId: userId}
|
||||
err = DB.Where(token).First(&token).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return token.Delete()
|
||||
}
|
||||
|
||||
func IncreaseTokenQuota(id int, quota int64) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||
return nil
|
||||
}
|
||||
return increaseTokenQuota(id, quota)
|
||||
}
|
||||
|
||||
func increaseTokenQuota(id int, quota int64) (err error) {
|
||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||
"accessed_time": helper.GetTimestamp(),
|
||||
},
|
||||
).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func DecreaseTokenQuota(id int, quota int64) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||
return nil
|
||||
}
|
||||
return decreaseTokenQuota(id, quota)
|
||||
}
|
||||
|
||||
func decreaseTokenQuota(id int, quota int64) (err error) {
|
||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
"accessed_time": helper.GetTimestamp(),
|
||||
},
|
||||
).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
token, err := GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return errors.New("令牌额度不足")
|
||||
}
|
||||
userQuota, err := GetUserQuota(token.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if userQuota < quota {
|
||||
return errors.New("用户额度不足")
|
||||
}
|
||||
quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
|
||||
noMoreQuota := userQuota-quota <= 0
|
||||
if quotaTooLow || noMoreQuota {
|
||||
go func() {
|
||||
email, err := GetUserEmail(token.UserId)
|
||||
if err != nil {
|
||||
logger.SysError("failed to fetch user email: " + err.Error())
|
||||
}
|
||||
prompt := "额度提醒"
|
||||
var contentText string
|
||||
if noMoreQuota {
|
||||
contentText = "您的额度已用尽"
|
||||
} else {
|
||||
contentText = "您的额度即将用尽"
|
||||
}
|
||||
if email != "" {
|
||||
topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
|
||||
content := message.EmailTemplate(
|
||||
prompt,
|
||||
fmt.Sprintf(`
|
||||
<p>您好!</p>
|
||||
<p>%s,当前剩余额度为 <strong>%d</strong>。</p>
|
||||
<p>为了不影响您的使用,请及时充值。</p>
|
||||
<p style="text-align: center; margin: 30px 0;">
|
||||
<a href="%s" style="background-color: #007bff; color: white; padding: 12px 24px; text-decoration: none; border-radius: 4px; display: inline-block;">立即充值</a>
|
||||
</p>
|
||||
<p style="color: #666;">如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||
<p style="background-color: #f8f8f8; padding: 10px; border-radius: 4px; word-break: break-all;">%s</p>
|
||||
`, contentText, userQuota, topUpLink, topUpLink),
|
||||
)
|
||||
err = message.SendEmail(prompt, email, content)
|
||||
if err != nil {
|
||||
logger.SysError("failed to send email: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
return err
|
||||
}
|
||||
|
||||
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||
token, err := GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if quota > 0 {
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
} else {
|
||||
err = IncreaseUserQuota(token.UserId, -quota)
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(tokenId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
453
model/user.go
Normal file
453
model/user.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/blacklist"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleGuestUser = 0
|
||||
RoleCommonUser = 1
|
||||
RoleAdminUser = 10
|
||||
RoleRootUser = 100
|
||||
)
|
||||
|
||||
const (
|
||||
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
UserStatusDisabled = 2 // also don't use 0
|
||||
UserStatusDeleted = 3
|
||||
)
|
||||
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
||||
type User struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, util
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int64 `json:"quota" gorm:"bigint;default:0"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
}
|
||||
|
||||
func GetMaxUserId() int {
|
||||
var user User
|
||||
DB.Last(&user)
|
||||
return user.Id
|
||||
}
|
||||
|
||||
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
|
||||
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
|
||||
|
||||
switch order {
|
||||
case "quota":
|
||||
query = query.Order("quota desc")
|
||||
case "used_quota":
|
||||
query = query.Order("used_quota desc")
|
||||
case "request_count":
|
||||
query = query.Order("request_count desc")
|
||||
default:
|
||||
query = query.Order("id desc")
|
||||
}
|
||||
|
||||
err = query.Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string) (users []*User, err error) {
|
||||
if !common.UsingPostgreSQL {
|
||||
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
} else {
|
||||
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
}
|
||||
return users, err
|
||||
}
|
||||
|
||||
func GetUserById(id int, selectAll bool) (*User, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
user := User{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&user, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserIdByAffCode(affCode string) (int, error) {
|
||||
if affCode == "" {
|
||||
return 0, errors.New("affCode 为空!")
|
||||
}
|
||||
var user User
|
||||
err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
|
||||
return user.Id, err
|
||||
}
|
||||
|
||||
func DeleteUserById(id int) (err error) {
|
||||
if id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
user := User{Id: id}
|
||||
return user.Delete()
|
||||
}
|
||||
|
||||
func (user *User) Insert(ctx context.Context, inviterId int) error {
|
||||
var err error
|
||||
if user.Password != "" {
|
||||
user.Password, err = common.Password2Hash(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
user.Quota = config.QuotaForNewUser
|
||||
user.AccessToken = random.GetUUID()
|
||||
user.AffCode = random.GetRandomString(4)
|
||||
result := DB.Create(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if config.QuotaForNewUser > 0 {
|
||||
RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 {
|
||||
if config.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
|
||||
RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
|
||||
}
|
||||
if config.QuotaForInviter > 0 {
|
||||
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
|
||||
RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
|
||||
}
|
||||
}
|
||||
// create default token
|
||||
cleanToken := Token{
|
||||
UserId: user.Id,
|
||||
Name: "default",
|
||||
Key: random.GenerateKey(),
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
AccessedTime: helper.GetTimestamp(),
|
||||
ExpiredTime: -1,
|
||||
RemainQuota: -1,
|
||||
UnlimitedQuota: true,
|
||||
}
|
||||
result.Error = cleanToken.Insert()
|
||||
if result.Error != nil {
|
||||
// do not block
|
||||
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) Update(updatePassword bool) error {
|
||||
var err error
|
||||
if updatePassword {
|
||||
user.Password, err = common.Password2Hash(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if user.Status == UserStatusDisabled {
|
||||
blacklist.BanUser(user.Id)
|
||||
} else if user.Status == UserStatusEnabled {
|
||||
blacklist.UnbanUser(user.Id)
|
||||
}
|
||||
err = DB.Model(user).Updates(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) Delete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
blacklist.BanUser(user.Id)
|
||||
user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
|
||||
user.Status = UserStatusDeleted
|
||||
err := DB.Model(user).Updates(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateAndFill check password & user status
|
||||
func (user *User) ValidateAndFill() (err error) {
|
||||
// When querying with struct, GORM will only query with non-zero fields,
|
||||
// that means if your field’s value is 0, '', false or other zero values,
|
||||
// it won’t be used to build query conditions
|
||||
password := user.Password
|
||||
if user.Username == "" || password == "" {
|
||||
return errors.New("用户名或密码为空")
|
||||
}
|
||||
err = DB.Where("username = ?", user.Username).First(user).Error
|
||||
if err != nil {
|
||||
// we must make sure check username firstly
|
||||
// consider this case: a malicious user set his username as other's email
|
||||
err := DB.Where("email = ?", user.Username).First(user).Error
|
||||
if err != nil {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
}
|
||||
}
|
||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||
if !okay || user.Status != UserStatusEnabled {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserById() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
DB.Where(User{Id: user.Id}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByEmail() error {
|
||||
if user.Email == "" {
|
||||
return errors.New("email 为空!")
|
||||
}
|
||||
DB.Where(User{Email: user.Email}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByGitHubId() error {
|
||||
if user.GitHubId == "" {
|
||||
return errors.New("GitHub id 为空!")
|
||||
}
|
||||
DB.Where(User{GitHubId: user.GitHubId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByLarkId() error {
|
||||
if user.LarkId == "" {
|
||||
return errors.New("lark id 为空!")
|
||||
}
|
||||
DB.Where(User{LarkId: user.LarkId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
}
|
||||
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
}
|
||||
DB.Where(User{WeChatId: user.WeChatId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByUsername() error {
|
||||
if user.Username == "" {
|
||||
return errors.New("username 为空!")
|
||||
}
|
||||
DB.Where(User{Username: user.Username}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsEmailAlreadyTaken(email string) bool {
|
||||
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsWeChatIdAlreadyTaken(wechatId string) bool {
|
||||
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsLarkIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func ResetUserPasswordByEmail(email string, password string) error {
|
||||
if email == "" || password == "" {
|
||||
return errors.New("邮箱地址或密码为空!")
|
||||
}
|
||||
hashedPassword, err := common.Password2Hash(password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func IsAdmin(userId int) bool {
|
||||
if userId == 0 {
|
||||
return false
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||
if err != nil {
|
||||
logger.SysError("no such user " + err.Error())
|
||||
return false
|
||||
}
|
||||
return user.Role >= RoleAdminUser
|
||||
}
|
||||
|
||||
func IsUserEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return user.Status == UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
}
|
||||
token = strings.Replace(token, "Bearer ", "", 1)
|
||||
user = &User{}
|
||||
if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
|
||||
return user
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserQuota(id int) (quota int64, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||
return quota, err
|
||||
}
|
||||
|
||||
func GetUserUsedQuota(id int) (quota int64, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
|
||||
return quota, err
|
||||
}
|
||||
|
||||
func GetUserEmail(id int) (email string, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
|
||||
return email, err
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int64) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||
return nil
|
||||
}
|
||||
return increaseUserQuota(id, quota)
|
||||
}
|
||||
|
||||
func increaseUserQuota(id int, quota int64) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func DecreaseUserQuota(id int, quota int64) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||
return nil
|
||||
}
|
||||
return decreaseUserQuota(id, quota)
|
||||
}
|
||||
|
||||
func decreaseUserQuota(id int, quota int64) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func GetRootUserEmail() (email string) {
|
||||
DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
|
||||
return email
|
||||
}
|
||||
|
||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
|
||||
if config.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||
return
|
||||
}
|
||||
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||
}
|
||||
|
||||
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
"request_count": gorm.Expr("request_count + ?", count),
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int64) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update user used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserRequestCount(id int, count int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||
if err != nil {
|
||||
logger.SysError("failed to update user request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func GetUsernameById(id int) (username string) {
|
||||
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
||||
return username
|
||||
}
|
||||
78
model/utils.go
Normal file
78
model/utils.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
BatchUpdateTypeUserQuota = iota
|
||||
BatchUpdateTypeTokenQuota
|
||||
BatchUpdateTypeUsedQuota
|
||||
BatchUpdateTypeChannelUsedQuota
|
||||
BatchUpdateTypeRequestCount
|
||||
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||
)
|
||||
|
||||
var batchUpdateStores []map[int]int64
|
||||
var batchUpdateLocks []sync.Mutex
|
||||
|
||||
func init() {
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
|
||||
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||
}
|
||||
}
|
||||
|
||||
func InitBatchUpdater() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
|
||||
batchUpdate()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func addNewRecord(type_ int, id int, value int64) {
|
||||
batchUpdateLocks[type_].Lock()
|
||||
defer batchUpdateLocks[type_].Unlock()
|
||||
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||
batchUpdateStores[type_][id] = value
|
||||
} else {
|
||||
batchUpdateStores[type_][id] += value
|
||||
}
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
logger.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
store := batchUpdateStores[i]
|
||||
batchUpdateStores[i] = make(map[int]int64)
|
||||
batchUpdateLocks[i].Unlock()
|
||||
// TODO: maybe we can combine updates with same key?
|
||||
for key, value := range store {
|
||||
switch i {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
err := increaseUserQuota(key, value)
|
||||
if err != nil {
|
||||
logger.SysError("failed to batch update user quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeTokenQuota:
|
||||
err := increaseTokenQuota(key, value)
|
||||
if err != nil {
|
||||
logger.SysError("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
case BatchUpdateTypeRequestCount:
|
||||
updateUserRequestCount(key, int(value))
|
||||
case BatchUpdateTypeChannelUsedQuota:
|
||||
updateChannelUsedQuota(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.SysLog("batch update finished")
|
||||
}
|
||||
Reference in New Issue
Block a user