Merge branch 'alpha' into refactor_error

# Conflicts:
#	controller/channel.go
#	middleware/distributor.go
#	model/channel.go
#	model/user.go
#	model/user_cache.go
#	relay/common/relay_info.go
This commit is contained in:
CaIon
2025-07-10 15:11:55 +08:00
32 changed files with 395 additions and 258 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"strings"
"sync"
"github.com/samber/lo"
"gorm.io/gorm"
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
var fixLock = sync.Mutex{}
func FixAbility() (int, int, error) {
lock := fixLock.TryLock()
if !lock {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
}
defer fixLock.Unlock()
var channels []*Channel
// Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
return 0, 0, err
}
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if len(channels) == 0 {
return 0, 0, nil
}
successCount := 0
failCount := 0
for _, chunk := range lo.Chunk(channels, 50) {
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
// Then add new abilities
for _, channel := range chunk {
err = channel.AddAbilities()
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return count, nil
return successCount, failCount, nil
}

View File

@@ -7,6 +7,7 @@ import (
"math/rand"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
"sync"
@@ -610,8 +611,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
return tags, nil
}
func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{})
func (channel *Channel) ValidateSettings() error {
channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) GetSetting() dto.ChannelSettings {
setting := dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
@@ -621,7 +633,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
return setting
}
func (channel *Channel) SetSetting(setting map[string]interface{}) {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"one-api/common"
"one-api/constant"
"os"
"strings"
"time"
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
}
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
type RecordConsumeLogParams struct {
ChannelId int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
TokenName string `json:"token_name"`
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
Other map[string]interface{} `json:"other"`
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled {
return
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
if vb, ok := v.(bool); ok && vb {
needRecordIp = true
}
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Content: params.Content,
PromptTokens: params.PromptTokens,
CompletionTokens: params.CompletionTokens,
TokenName: params.TokenName,
ModelName: params.ModelName,
Quota: params.Quota,
ChannelId: params.ChannelId,
TokenId: params.TokenId,
UseTime: params.UseTimeSeconds,
IsStream: params.IsStream,
Group: params.Group,
Ip: func() string {
if needRecordIp {
return c.ClientIP()
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
}
if common.DataExportEnabled {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
@@ -68,19 +69,18 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() (map[string]interface{}, error) {
if user.Setting == "" {
return map[string]interface{}{}, nil
func (user *User) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
return setting
}
func (user *User) SetSetting(setting map[string]interface{}) {
func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
@@ -631,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
@@ -653,15 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
return settingMap, err
}
toMap, err := common.StrToMap(setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
userBase := &UserBase{
Setting: setting,
}
return toMap, nil
return userBase.GetSetting(), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"time"
"github.com/gin-gonic/gin"
@@ -32,25 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) {
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
func (user *UserBase) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := common.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
}
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert user setting to map: " + err.Error())
return nil
}
return toMap
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
return setting
}
// getUserCacheKey returns the key for user cache
@@ -179,11 +170,10 @@ func getUserNameCache(userId int) (string, error) {
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
func getUserSettingCache(userId int) (dto.UserSetting, error) {
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
return dto.UserSetting{}, err
}
return cache.GetSetting(), nil
}