refactor: Introduce pre-consume quota and unify relay handlers
This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic. Key changes: - **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests. - **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels. - **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package. - **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -294,13 +295,13 @@ func FixAbility() (int, int, error) {
|
||||
if common.UsingSQLite {
|
||||
err := DB.Exec("DELETE FROM abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
} else {
|
||||
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||
logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
@@ -320,7 +321,7 @@ func FixAbility() (int, int, error) {
|
||||
// Delete all abilities of this channel
|
||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
failCount += len(chunk)
|
||||
continue
|
||||
}
|
||||
@@ -328,7 +329,7 @@ func FixAbility() (int, int, error) {
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
if channel.OtherInfo != "" {
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
logger.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
}
|
||||
return otherInfo
|
||||
@@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
||||
otherInfoBytes, err := json.Marshal(otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal other info: " + err.Error())
|
||||
logger.SysError("failed to marshal other info: " + err.Error())
|
||||
return
|
||||
}
|
||||
channel.OtherInfo = string(otherInfoBytes)
|
||||
@@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||||
ResponseTime: int(responseTime),
|
||||
}).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update response time: " + err.Error())
|
||||
logger.SysError("failed to update response time: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
|
||||
Balance: balance,
|
||||
}).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update balance: " + err.Error())
|
||||
logger.SysError("failed to update balance: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
if shouldUpdateAbilities {
|
||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
logger.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
}
|
||||
err = channel.Save()
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
logger.SysError("failed to update channel status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
|
||||
for _, channel := range channels {
|
||||
err = channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError("failed to update abilities: " + err.Error())
|
||||
logger.SysError("failed to update abilities: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
||||
func updateChannelUsedQuota(id int, quota int) {
|
||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel used quota: " + err.Error())
|
||||
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
_ = channel.Save() // 保存修改
|
||||
}
|
||||
@@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
logger.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||||
@@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
||||
if channel.OtherSettings != "" {
|
||||
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
||||
_ = channel.Save() // 保存修改
|
||||
}
|
||||
@@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
||||
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
logger.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
channel.OtherSettings = string(settingBytes)
|
||||
@@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
logger.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
}
|
||||
return paramOverride
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/logger"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sort"
|
||||
@@ -84,13 +85,13 @@ func InitChannelCache() {
|
||||
}
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
logger.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
logger.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
12
model/log.go
12
model/log.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to record log: " + err.Error())
|
||||
logger.SysError("failed to record log: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
@@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to record log: "+err.Error())
|
||||
logger.LogError(c, "failed to record log: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,7 +143,6 @@ type RecordConsumeLogParams struct {
|
||||
Quota int `json:"quota"`
|
||||
Content string `json:"content"`
|
||||
TokenId int `json:"token_id"`
|
||||
UserQuota int `json:"user_quota"`
|
||||
UseTimeSeconds int `json:"use_time_seconds"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
Group string `json:"group"`
|
||||
@@ -150,7 +150,7 @@ type RecordConsumeLogParams struct {
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
@@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.LogError(c, "failed to record log: "+err.Error())
|
||||
logger.LogError(c, "failed to record log: "+err.Error())
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
gopool.Go(func() {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/logger"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -84,7 +85,7 @@ func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
if err := DB.First(&user).Error; err != nil {
|
||||
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||
hashedPassword, err := common.Password2Hash("123456")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -108,7 +109,7 @@ func CheckSetup() {
|
||||
if setup == nil {
|
||||
// No setup record exists, check if we have a root user
|
||||
if RootUserExists() {
|
||||
common.SysLog("system is not initialized, but root user exists")
|
||||
logger.SysLog("system is not initialized, but root user exists")
|
||||
// Create setup record
|
||||
newSetup := Setup{
|
||||
Version: common.Version,
|
||||
@@ -116,16 +117,16 @@ func CheckSetup() {
|
||||
}
|
||||
err := DB.Create(&newSetup).Error
|
||||
if err != nil {
|
||||
common.SysLog("failed to create setup record: " + err.Error())
|
||||
logger.SysLog("failed to create setup record: " + err.Error())
|
||||
}
|
||||
constant.Setup = true
|
||||
} else {
|
||||
common.SysLog("system is not initialized and no root user exists")
|
||||
logger.SysLog("system is not initialized and no root user exists")
|
||||
constant.Setup = false
|
||||
}
|
||||
} else {
|
||||
// Setup record exists, system is initialized
|
||||
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
||||
logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
|
||||
constant.Setup = true
|
||||
}
|
||||
}
|
||||
@@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
if dsn != "" {
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
logger.SysLog("using PostgreSQL as database")
|
||||
if !isLog {
|
||||
common.UsingPostgreSQL = true
|
||||
} else {
|
||||
@@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
if strings.HasPrefix(dsn, "local") {
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
if !isLog {
|
||||
common.UsingSQLite = true
|
||||
} else {
|
||||
@@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
// Use MySQL
|
||||
common.SysLog("using MySQL as database")
|
||||
logger.SysLog("using MySQL as database")
|
||||
// check parseTime
|
||||
if !strings.Contains(dsn, "parseTime") {
|
||||
if strings.Contains(dsn, "?") {
|
||||
@@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
// Use SQLite
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
@@ -216,11 +217,11 @@ func InitDB() (err error) {
|
||||
if common.UsingMySQL {
|
||||
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
logger.SysLog("database migration started")
|
||||
err = migrateDB()
|
||||
return err
|
||||
} else {
|
||||
common.FatalLog(err)
|
||||
logger.FatalLog(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -253,11 +254,11 @@ func InitLogDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
logger.SysLog("database migration started")
|
||||
err = migrateLOGDB()
|
||||
return err
|
||||
} else {
|
||||
common.FatalLog(err)
|
||||
logger.FatalLog(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -354,7 +355,7 @@ func migrateDBFast() error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
logger.SysLog("database migrated")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -503,6 +504,6 @@ func PingDB() error {
|
||||
}
|
||||
|
||||
lastPingTime = time.Now()
|
||||
common.SysLog("Database pinged successfully")
|
||||
logger.SysLog("Database pinged successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
@@ -150,7 +151,7 @@ func loadOptionsFromDatabase() {
|
||||
for _, option := range options {
|
||||
err := updateOptionMap(option.Key, option.Value)
|
||||
if err != nil {
|
||||
common.SysError("failed to update option map: " + err.Error())
|
||||
logger.SysError("failed to update option map: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,7 +159,7 @@ func loadOptionsFromDatabase() {
|
||||
func SyncOptions(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing options from database")
|
||||
logger.SysLog("syncing options from database")
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
|
||||
"one-api/common"
|
||||
@@ -92,7 +93,7 @@ func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
// 预加载模型元数据与供应商一次,避免循环查询
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if err != nil {
|
||||
return 0, errors.New("兑换失败," + err.Error())
|
||||
}
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
|
||||
return redemption.Quota, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
@@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return token, errors.New("该令牌已过期")
|
||||
@@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
keyPrefix := key[:3]
|
||||
@@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
logger.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(*token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
logger.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func (token *Token) Update() (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
logger.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
logger.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheDeleteToken(token.Key)
|
||||
if err != nil {
|
||||
common.SysError("failed to delete token cache: " + err.Error())
|
||||
logger.SysError("failed to delete token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase token quota: " + err.Error())
|
||||
logger.SysError("failed to increase token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease token quota: " + err.Error())
|
||||
logger.SysError("failed to decrease token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
|
||||
return errors.New("充值失败," + err.Error())
|
||||
}
|
||||
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
logger.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
logger.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
@@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||
if !valid {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
logger.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
logger.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -24,12 +25,12 @@ func UpdateQuotaData() {
|
||||
// recover
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
||||
logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
if common.DataExportEnabled {
|
||||
common.SysLog("正在更新数据看板数据...")
|
||||
logger.SysLog("正在更新数据看板数据...")
|
||||
SaveQuotaDataCache()
|
||||
}
|
||||
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
|
||||
@@ -91,7 +92,7 @@ func SaveQuotaDataCache() {
|
||||
}
|
||||
}
|
||||
CacheQuotaData = make(map[string]*QuotaData)
|
||||
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
|
||||
logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
|
||||
}
|
||||
|
||||
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
|
||||
@@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int,
|
||||
"token_used": gorm.Expr("token_used + ?", tokenUsed),
|
||||
}).Error
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
|
||||
logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return setting
|
||||
@@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
|
||||
func (user *User) SetSetting(setting dto.UserSetting) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
logger.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.Setting = string(settingBytes)
|
||||
@@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
|
||||
func (user *User) TransferAffQuotaToQuota(quota int) error {
|
||||
// 检查quota是否小于最小额度
|
||||
if float64(quota) < common.QuotaPerUnit {
|
||||
return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
|
||||
return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
|
||||
}
|
||||
|
||||
// 开始数据库事务
|
||||
@@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
|
||||
return result.Error
|
||||
}
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||
}
|
||||
if common.QuotaForInviter > 0 {
|
||||
//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
|
||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
|
||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
|
||||
_ = inviteUser(inviterId)
|
||||
}
|
||||
}
|
||||
@@ -517,7 +518,7 @@ func IsAdmin(userId int) bool {
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||
if err != nil {
|
||||
common.SysError("no such user " + err.Error())
|
||||
logger.SysError("no such user " + err.Error())
|
||||
return false
|
||||
}
|
||||
return user.Role >= common.RoleAdminUser
|
||||
@@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserQuotaCache(id, quota); err != nil {
|
||||
common.SysError("failed to update user quota cache: " + err.Error())
|
||||
logger.SysError("failed to update user quota cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserGroupCache(id, group); err != nil {
|
||||
common.SysError("failed to update user group cache: " + err.Error())
|
||||
logger.SysError("failed to update user group cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserSettingCache(id, setting); err != nil {
|
||||
common.SysError("failed to update user setting cache: " + err.Error())
|
||||
logger.SysError("failed to update user setting cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase user quota: " + err.Error())
|
||||
logger.SysError("failed to increase user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if !db && common.BatchUpdateEnabled {
|
||||
@@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease user quota: " + err.Error())
|
||||
logger.SysError("failed to decrease user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
@@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota: " + err.Error())
|
||||
logger.SysError("failed to update user used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserRequestCount(id int, count int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user request count: " + err.Error())
|
||||
logger.SysError("failed to update user request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserNameCache(id, username); err != nil {
|
||||
common.SysError("failed to update user name cache: " + err.Error())
|
||||
logger.SysError("failed to update user name cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
|
||||
if user.Setting != "" {
|
||||
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
logger.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return setting
|
||||
@@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
|
||||
if shouldUpdateRedis(fromDB, err) && user != nil {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserCache(*user); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
logger.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -65,7 +66,7 @@ func batchUpdate() {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
logger.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
store := batchUpdateStores[i]
|
||||
@@ -77,12 +78,12 @@ func batchUpdate() {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
err := increaseUserQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysError("failed to batch update user quota: " + err.Error())
|
||||
logger.SysError("failed to batch update user quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeTokenQuota:
|
||||
err := increaseTokenQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysError("failed to batch update token quota: " + err.Error())
|
||||
logger.SysError("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
@@ -93,7 +94,7 @@ func batchUpdate() {
|
||||
}
|
||||
}
|
||||
}
|
||||
common.SysLog("batch update finished")
|
||||
logger.SysLog("batch update finished")
|
||||
}
|
||||
|
||||
func RecordExist(err error) (bool, error) {
|
||||
|
||||
Reference in New Issue
Block a user