refactor: Introduce pre-consume quota and unify relay handlers

This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
CaIon
2025-08-14 20:05:06 +08:00
parent 17bab355e4
commit e2037ad756
113 changed files with 3095 additions and 2518 deletions

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/logger"
"strings"
"sync"
@@ -294,13 +295,13 @@ func FixAbility() (int, int, error) {
if common.UsingSQLite {
err := DB.Exec("DELETE FROM abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
return 0, 0, err
}
} else {
err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
return 0, 0, err
}
}
@@ -320,7 +321,7 @@ func FixAbility() (int, int, error) {
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
@@ -328,7 +329,7 @@ func FixAbility() (int, int, error) {
for _, channel := range chunk {
err = channel.AddAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/types"
"strings"
"sync"
@@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
if channel.OtherInfo != "" {
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error())
logger.SysError("failed to unmarshal other info: " + err.Error())
}
}
return otherInfo
@@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
common.SysError("failed to marshal other info: " + err.Error())
logger.SysError("failed to marshal other info: " + err.Error())
return
}
channel.OtherInfo = string(otherInfoBytes)
@@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
ResponseTime: int(responseTime),
}).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
logger.SysError("failed to update response time: " + err.Error())
}
}
@@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
Balance: balance,
}).Error
if err != nil {
common.SysError("failed to update balance: " + err.Error())
logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
logger.SysError("failed to update ability status: " + err.Error())
}
}
}()
@@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
logger.SysError("failed to update channel status: " + err.Error())
return false
}
}
@@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
for _, channel := range channels {
err = channel.UpdateAbilities(nil)
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
logger.SysError("failed to update abilities: " + err.Error())
}
}
}
@@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())
logger.SysError("failed to update channel used quota: " + err.Error())
}
}
@@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
if channel.Setting != nil && *channel.Setting != "" {
err := common.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
logger.SysError("failed to unmarshal setting: " + err.Error())
channel.Setting = nil // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
}
@@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := common.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
logger.SysError("failed to marshal setting: " + err.Error())
return
}
channel.Setting = common.GetPointer[string](string(settingBytes))
@@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
if channel.OtherSettings != "" {
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
logger.SysError("failed to unmarshal setting: " + err.Error())
channel.OtherSettings = "{}" // 清空设置以避免后续错误
_ = channel.Save() // 保存修改
}
@@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
settingBytes, err := common.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
logger.SysError("failed to marshal setting: " + err.Error())
return
}
channel.OtherSettings = string(settingBytes)
@@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error())
logger.SysError("failed to unmarshal param override: " + err.Error())
}
}
return paramOverride

View File

@@ -6,6 +6,7 @@ import (
"math/rand"
"one-api/common"
"one-api/constant"
"one-api/logger"
"one-api/setting"
"one-api/setting/ratio_setting"
"sort"
@@ -84,13 +85,13 @@ func InitChannelCache() {
}
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database")
logger.SysLog("syncing channels from database")
InitChannelCache()
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
"one-api/logger"
"os"
"strings"
"time"
@@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) {
}
err := LOG_DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP
@@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(c, "failed to record log: "+err.Error())
logger.LogError(c, "failed to record log: "+err.Error())
}
}
@@ -142,7 +143,6 @@ type RecordConsumeLogParams struct {
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
@@ -150,7 +150,7 @@ type RecordConsumeLogParams struct {
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled {
return
}
@@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(c, "failed to record log: "+err.Error())
logger.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {

View File

@@ -5,6 +5,7 @@ import (
"log"
"one-api/common"
"one-api/constant"
"one-api/logger"
"os"
"strings"
"sync"
@@ -84,7 +85,7 @@ func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -108,7 +109,7 @@ func CheckSetup() {
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
common.SysLog("system is not initialized, but root user exists")
logger.SysLog("system is not initialized, but root user exists")
// Create setup record
newSetup := Setup{
Version: common.Version,
@@ -116,16 +117,16 @@ func CheckSetup() {
}
err := DB.Create(&newSetup).Error
if err != nil {
common.SysLog("failed to create setup record: " + err.Error())
logger.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
common.SysLog("system is not initialized and no root user exists")
logger.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
@@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
logger.SysLog("using PostgreSQL as database")
if !isLog {
common.UsingPostgreSQL = true
} else {
@@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
logger.SysLog("SQL_DSN not set, using SQLite as database")
if !isLog {
common.UsingSQLite = true
} else {
@@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
// Use MySQL
common.SysLog("using MySQL as database")
logger.SysLog("using MySQL as database")
// check parseTime
if !strings.Contains(dsn, "parseTime") {
if strings.Contains(dsn, "?") {
@@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
})
}
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
@@ -216,11 +217,11 @@ func InitDB() (err error) {
if common.UsingMySQL {
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
logger.SysLog("database migration started")
err = migrateDB()
return err
} else {
common.FatalLog(err)
logger.FatalLog(err)
}
return err
}
@@ -253,11 +254,11 @@ func InitLogDB() (err error) {
if !common.IsMasterNode {
return nil
}
common.SysLog("database migration started")
logger.SysLog("database migration started")
err = migrateLOGDB()
return err
} else {
common.FatalLog(err)
logger.FatalLog(err)
}
return err
}
@@ -354,7 +355,7 @@ func migrateDBFast() error {
return err
}
}
common.SysLog("database migrated")
logger.SysLog("database migrated")
return nil
}
@@ -503,6 +504,6 @@ func PingDB() error {
}
lastPingTime = time.Now()
common.SysLog("Database pinged successfully")
logger.SysLog("Database pinged successfully")
return nil
}

View File

@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
"one-api/logger"
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
@@ -150,7 +151,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
common.SysError("failed to update option map: " + err.Error())
logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -158,7 +159,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing options from database")
logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}

View File

@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"one-api/logger"
"strings"
"one-api/common"
@@ -92,7 +93,7 @@ func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
// 预加载模型元数据与供应商一次,避免循环查询

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/logger"
"strconv"
"gorm.io/gorm"
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/logger"
"strings"
"github.com/bytedance/gopkg/util/gopool"
@@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
logger.SysError("failed to update token status" + err.Error())
}
}
return token, errors.New("该令牌已过期")
@@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
logger.SysError("failed to update token status" + err.Error())
}
}
keyPrefix := key[:3]
@@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
if err := cacheSetToken(token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
logger.SysError("failed to update user status cache: " + err.Error())
}
})
}
@@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() {
if err := cacheSetToken(*token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
logger.SysError("failed to update user status cache: " + err.Error())
}
})
}
@@ -178,7 +179,7 @@ func (token *Token) Update() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
logger.SysError("failed to update token cache: " + err.Error())
}
})
}
@@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
logger.SysError("failed to update token cache: " + err.Error())
}
})
}
@@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) {
gopool.Go(func() {
err := cacheDeleteToken(token.Key)
if err != nil {
common.SysError("failed to delete token cache: " + err.Error())
logger.SysError("failed to delete token cache: " + err.Error())
}
})
}
@@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to increase token quota: " + err.Error())
logger.SysError("failed to increase token quota: " + err.Error())
}
})
}
@@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to decrease token quota: " + err.Error())
logger.SysError("failed to decrease token quota: " + err.Error())
}
})
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/logger"
"gorm.io/gorm"
)
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.FormatQuota(int(quota)), topUp.Amount))
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", logger.FormatQuota(int(quota)), topUp.Amount))
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/logger"
"time"
"gorm.io/gorm"
@@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
logger.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
@@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
logger.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
@@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
if !valid {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
logger.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
@@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
logger.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
"one-api/logger"
"sync"
"time"
)
@@ -24,12 +25,12 @@ func UpdateQuotaData() {
// recover
defer func() {
if r := recover(); r != nil {
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
}
}()
for {
if common.DataExportEnabled {
common.SysLog("正在更新数据看板数据...")
logger.SysLog("正在更新数据看板数据...")
SaveQuotaDataCache()
}
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
@@ -91,7 +92,7 @@ func SaveQuotaDataCache() {
}
}
CacheQuotaData = make(map[string]*QuotaData)
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
}
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
@@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int,
"token_used": gorm.Expr("token_used + ?", tokenUsed),
}).Error
if err != nil {
common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"one-api/common"
"one-api/dto"
"one-api/logger"
"strconv"
"strings"
@@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
logger.SysError("failed to unmarshal setting: " + err.Error())
}
}
return setting
@@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
logger.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
@@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
func (user *User) TransferAffQuotaToQuota(quota int) error {
// 检查quota是否小于最小额度
if float64(quota) < common.QuotaPerUnit {
return fmt.Errorf("转移额度最小为%s", common.LogQuota(int(common.QuotaPerUnit)))
return fmt.Errorf("转移额度最小为%s", logger.LogQuota(int(common.QuotaPerUnit)))
}
// 开始数据库事务
@@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId)
}
}
@@ -517,7 +518,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
common.SysError("no such user " + err.Error())
logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
common.SysError("failed to update user quota cache: " + err.Error())
logger.SysError("failed to update user quota cache: " + err.Error())
}
})
}
@@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
common.SysError("failed to update user group cache: " + err.Error())
logger.SysError("failed to update user group cache: " + err.Error())
}
})
}
@@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
logger.SysError("failed to update user setting cache: " + err.Error())
}
})
}
@@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
gopool.Go(func() {
err := cacheIncrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to increase user quota: " + err.Error())
logger.SysError("failed to increase user quota: " + err.Error())
}
})
if !db && common.BatchUpdateEnabled {
@@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to decrease user quota: " + err.Error())
logger.SysError("failed to decrease user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
@@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error())
logger.SysError("failed to update user used quota and request count: " + err.Error())
return
}
@@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
logger.SysError("failed to update user request count: " + err.Error())
}
}
@@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
common.SysError("failed to update user name cache: " + err.Error())
logger.SysError("failed to update user name cache: " + err.Error())
}
})
}

View File

@@ -5,6 +5,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"time"
"github.com/gin-gonic/gin"
@@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
if user.Setting != "" {
err := common.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
logger.SysError("failed to unmarshal setting: " + err.Error())
}
}
return setting
@@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
logger.SysError("failed to update user status cache: " + err.Error())
}
})
}

View File

@@ -3,6 +3,7 @@ package model
import (
"errors"
"one-api/common"
"one-api/logger"
"sync"
"time"
@@ -65,7 +66,7 @@ func batchUpdate() {
return
}
common.SysLog("batch update started")
logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
@@ -77,12 +78,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysError("failed to batch update user quota: " + err.Error())
logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
@@ -93,7 +94,7 @@ func batchUpdate() {
}
}
}
common.SysLog("batch update finished")
logger.SysLog("batch update finished")
}
func RecordExist(err error) (bool, error) {