diff --git a/common/gin.go b/common/gin.go index 76b59d68..f876a92b 100644 --- a/common/gin.go +++ b/common/gin.go @@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string] func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { return c.GetTime(string(key)) } + +func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) { + if value, ok := c.Get(string(key)); ok { + if v, ok := value.(T); ok { + return v, true + } + } + var t T + return t, false +} diff --git a/common/str.go b/common/str.go index 5906f923..88b58c72 100644 --- a/common/str.go +++ b/common/str.go @@ -1,6 +1,7 @@ package common import ( + "encoding/base64" "encoding/json" "math/rand" "strconv" @@ -82,3 +83,15 @@ func StringToByteSlice(s string) []byte { tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} return *(*[]byte)(unsafe.Pointer(&tmp2)) } + +func EncodeBase64(str string) string { + return base64.StdEncoding.EncodeToString([]byte(str)) +} + +func GetJsonString(data any) string { + if data == nil { + return "" + } + b, _ := json.Marshal(data) + return string(b) +} diff --git a/constant/channel_setting.go b/constant/channel_setting.go deleted file mode 100644 index e06e7eb1..00000000 --- a/constant/channel_setting.go +++ /dev/null @@ -1,7 +0,0 @@ -package constant - -var ( - ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 - ChanelSettingProxy = "proxy" // Proxy 代理 - ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent -) diff --git a/constant/user_setting.go b/constant/user_setting.go deleted file mode 100644 index 7e79035e..00000000 --- a/constant/user_setting.go +++ /dev/null @@ -1,16 +0,0 @@ -package constant - -var ( - UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型 - UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值 - UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 - UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 - UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 - UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型 - UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP -) - -var ( - NotifyTypeEmail = "email" // Email 邮件 - NotifyTypeWebhook = "webhook" // Webhook -) diff --git a/controller/channel-test.go b/controller/channel-test.go index 8fd3d2d8..89c1a133 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -174,8 +174,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr consumedTime := float64(milliseconds) / 1000.0 other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other) + model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ + ChannelId: channel.Id, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: info.OriginModelName, + TokenName: "模型测试", + Quota: quota, + Content: "模型测试", + UseTimeSeconds: int(consumedTime), + IsStream: false, + Group: info.UsingGroup, + Other: other, + }) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } @@ -342,6 +353,10 @@ func TestAllChannels(c *gin.Context) { } func AutomaticallyTestChannels(frequency int) { + if frequency <= 0 { + common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") + return + } for { time.Sleep(time.Duration(frequency) * time.Minute) common.SysLog("testing all channels") diff --git a/controller/channel.go b/controller/channel.go index 0f374fe4..c9f20fa5 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) { } func FixChannelsAbilities(c *gin.Context) { - count, err := model.FixAbility() + success, fails, err := model.FixAbility() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": count, + "data": gin.H{ + "success": success, + "fails": fails, + }, }) } @@ -425,6 +428,16 @@ func AddChannel(c *gin.Context) { }) return } + + err = addChannelRequest.Channel.ValidateSettings() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel setting 格式错误:" + err.Error(), + }) + return + } + if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -713,6 +726,14 @@ func UpdateChannel(c *gin.Context) { }) return } + err = channel.ValidateSettings() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel setting 格式错误:" + err.Error(), + }) + return + } if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/user.go b/controller/user.go index ca161f42..44450836 100644 --- a/controller/user.go +++ b/controller/user.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/dto" "one-api/model" "one-api/setting" "strconv" @@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) { } // 验证预警类型 - if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { + if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的预警类型", @@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) { } // 如果是webhook类型,验证webhook地址 - if req.QuotaWarningType == constant.NotifyTypeWebhook { + if req.QuotaWarningType == dto.NotifyTypeWebhook { if req.WebhookUrl == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) { } // 如果是邮件类型,验证邮箱地址 - if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" { // 验证邮箱格式 if !strings.Contains(req.NotificationEmail, "@") { c.JSON(http.StatusOK, gin.H{ @@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) { } // 构建设置 - settings := map[string]interface{}{ - constant.UserSettingNotifyType: req.QuotaWarningType, - constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, - "accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel, - constant.UserSettingRecordIpLog: req.RecordIpLog, + settings := dto.UserSetting{ + NotifyType: req.QuotaWarningType, + QuotaWarningThreshold: req.QuotaWarningThreshold, + AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel, + RecordIpLog: req.RecordIpLog, } // 如果是webhook类型,添加webhook相关设置 - if req.QuotaWarningType == constant.NotifyTypeWebhook { - settings[constant.UserSettingWebhookUrl] = req.WebhookUrl + if req.QuotaWarningType == dto.NotifyTypeWebhook { + settings.WebhookUrl = req.WebhookUrl if req.WebhookSecret != "" { - settings[constant.UserSettingWebhookSecret] = req.WebhookSecret + settings.WebhookSecret = req.WebhookSecret } } // 如果提供了通知邮箱,添加到设置中 - if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { - settings[constant.UserSettingNotificationEmail] = req.NotificationEmail + if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" { + settings.NotificationEmail = req.NotificationEmail } // 更新用户设置 diff --git a/dto/channel_settings.go b/dto/channel_settings.go new file mode 100644 index 00000000..871d6716 --- /dev/null +++ b/dto/channel_settings.go @@ -0,0 +1,7 @@ +package dto + +type ChannelSettings struct { + ForceFormat bool `json:"force_format,omitempty"` + ThinkingToContent bool `json:"thinking_to_content,omitempty"` + Proxy string `json:"proxy"` +} diff --git a/dto/user_settings.go b/dto/user_settings.go new file mode 100644 index 00000000..2e1a1541 --- /dev/null +++ b/dto/user_settings.go @@ -0,0 +1,16 @@ +package dto + +type UserSetting struct { + NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型 + QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值 + WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 + WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 + NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 + AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 + RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP +} + +var ( + NotifyTypeEmail = "email" // Email 邮件 + NotifyTypeWebhook = "webhook" // Webhook +) diff --git a/main.go b/main.go index 727d5db6..996b65b3 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,6 @@ func main() { return } - common.SetupLogger() common.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) @@ -69,9 +68,9 @@ func main() { if r := recover(); r != nil { common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once - _, fixErr := model.FixAbility() + _, _, fixErr := model.FixAbility() if fixErr != nil { - common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() @@ -169,6 +168,8 @@ func InitResources() error { common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") } + common.SetupLogger() + // 加载环境变量 common.InitEnv() diff --git a/model/ability.go b/model/ability.go index fb5301fe..ed124676 100644 --- a/model/ability.go +++ b/model/ability.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "strings" + "sync" "github.com/samber/lo" "gorm.io/gorm" @@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error } -func FixAbility() (int, error) { - var channelIds []int - count := 0 - // Find all channel ids from channel table - err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error +var fixLock = sync.Mutex{} + +func FixAbility() (int, int, error) { + lock := fixLock.TryLock() + if !lock { + return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试") + } + defer fixLock.Unlock() + var channels []*Channel + // Find all channels + err := DB.Model(&Channel{}).Find(&channels).Error if err != nil { - common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error())) - return 0, err + return 0, 0, err } - - // Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders - if len(channelIds) > 0 { - // Process deletion in chunks to avoid "too many placeholders" error - for _, chunk := range lo.Chunk(channelIds, 100) { - err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error - if err != nil { - common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error())) - return 0, err - } - } - } else { - // If no channels exist, delete all abilities - err = DB.Delete(&Ability{}).Error + if len(channels) == 0 { + return 0, 0, nil + } + successCount := 0 + failCount := 0 + for _, chunk := range lo.Chunk(channels, 50) { + ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id }) + // Delete all abilities of this channel + err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error())) - return 0, err + common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + failCount += len(chunk) + continue } - common.SysLog("Delete all abilities successfully") - return 0, nil - } - - common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds)) - count += len(channelIds) - - // Use channelIds to find channel not in abilities table - var abilityChannelIds []int - err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error - if err != nil { - common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error())) - return count, err - } - - var channels []Channel - if len(abilityChannelIds) == 0 { - err = DB.Find(&channels).Error - } else { - // Process query in chunks to avoid "too many placeholders" error - err = nil - for _, chunk := range lo.Chunk(abilityChannelIds, 100) { - var channelsChunk []Channel - err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error + // Then add new abilities + for _, channel := range chunk { + err = channel.AddAbilities() if err != nil { - common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error())) - return count, err + common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + failCount++ + } else { + successCount++ } - channels = append(channels, channelsChunk...) - } - } - - for _, channel := range channels { - err := channel.UpdateAbilities(nil) - if err != nil { - common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error())) - } else { - common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id)) - count++ } } InitChannelCache() - return count, nil + return successCount, failCount, nil } diff --git a/model/channel.go b/model/channel.go index 787cb33c..9d2ad853 100644 --- a/model/channel.go +++ b/model/channel.go @@ -7,6 +7,7 @@ import ( "math/rand" "one-api/common" "one-api/constant" + "one-api/dto" "strings" "sync" @@ -610,8 +611,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str return tags, nil } -func (channel *Channel) GetSetting() map[string]interface{} { - setting := make(map[string]interface{}) +func (channel *Channel) ValidateSettings() error { + channelParams := &dto.ChannelSettings{} + if channel.Setting != nil && *channel.Setting != "" { + err := json.Unmarshal([]byte(*channel.Setting), channelParams) + if err != nil { + return err + } + } + return nil +} + +func (channel *Channel) GetSetting() dto.ChannelSettings { + setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { err := json.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { @@ -621,7 +633,7 @@ func (channel *Channel) GetSetting() map[string]interface{} { return setting } -func (channel *Channel) SetSetting(setting map[string]interface{}) { +func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) diff --git a/model/log.go b/model/log.go index 1550aa91..45923075 100644 --- a/model/log.go +++ b/model/log.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "one-api/common" - "one-api/constant" "os" "strings" "time" @@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, // 判断是否需要记录 IP needRecordIp := false if settingMap, err := GetUserSetting(userId, false); err == nil { - if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok { - if vb, ok := v.(bool); ok && vb { - needRecordIp = true - } + if settingMap.RecordIpLog { + needRecordIp = true } } log := &Log{ @@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } } -func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int, - modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, - isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) +type RecordConsumeLogParams struct { + ChannelId int `json:"channel_id"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + ModelName string `json:"model_name"` + TokenName string `json:"token_name"` + Quota int `json:"quota"` + Content string `json:"content"` + TokenId int `json:"token_id"` + UserQuota int `json:"user_quota"` + UseTimeSeconds int `json:"use_time_seconds"` + IsStream bool `json:"is_stream"` + Group string `json:"group"` + Other map[string]interface{} `json:"other"` +} + +func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { + common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } username := c.GetString("username") - otherStr := common.MapToJsonStr(other) + otherStr := common.MapToJsonStr(params.Other) // 判断是否需要记录 IP needRecordIp := false if settingMap, err := GetUserSetting(userId, false); err == nil { - if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok { - if vb, ok := v.(bool); ok && vb { - needRecordIp = true - } + if settingMap.RecordIpLog { + needRecordIp = true } } log := &Log{ @@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in Username: username, CreatedAt: common.GetTimestamp(), Type: LogTypeConsume, - Content: content, - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TokenName: tokenName, - ModelName: modelName, - Quota: quota, - ChannelId: channelId, - TokenId: tokenId, - UseTime: useTimeSeconds, - IsStream: isStream, - Group: group, + Content: params.Content, + PromptTokens: params.PromptTokens, + CompletionTokens: params.CompletionTokens, + TokenName: params.TokenName, + ModelName: params.ModelName, + Quota: params.Quota, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + UseTime: params.UseTimeSeconds, + IsStream: params.IsStream, + Group: params.Group, Ip: func() string { if needRecordIp { return c.ClientIP() @@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in } if common.DataExportEnabled { gopool.Go(func() { - LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) + LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens) }) } } diff --git a/model/user.go b/model/user.go index 634d6754..6bb5a867 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/dto" "strconv" "strings" @@ -68,19 +69,18 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } -func (user *User) GetSetting() (map[string]interface{}, error) { - if user.Setting == "" { - return map[string]interface{}{}, nil +func (user *User) GetSetting() dto.UserSetting { + setting := dto.UserSetting{} + if user.Setting != "" { + err := json.Unmarshal([]byte(user.Setting), &setting) + if err != nil { + common.SysError("failed to unmarshal setting: " + err.Error()) + } } - toMap, err := common.StrToMap(user.Setting) - if err != nil { - common.SysError("failed to convert setting to map: " + err.Error()) - return nil, fmt.Errorf("failed to convert setting to map") - } - return toMap, nil + return setting } -func (user *User) SetSetting(setting map[string]interface{}) { +func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) @@ -631,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { } // GetUserSetting gets setting from Redis first, falls back to DB if needed -func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) { +func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) { var setting string defer func() { // Update Redis cache asynchronously on successful DB read @@ -653,15 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error if err != nil { - return map[string]interface{}{}, err + return settingMap, err } - - toMap, err := common.StrToMap(setting) - if err != nil { - common.SysError("failed to convert setting to map: " + err.Error()) - return nil, fmt.Errorf("failed to convert setting to map") + userBase := &UserBase{ + Setting: setting, } - return toMap, nil + return userBase.GetSetting(), nil } func IncreaseUserQuota(id int, quota int, db bool) (err error) { diff --git a/model/user_cache.go b/model/user_cache.go index 44eaa842..87fa973a 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/constant" + "one-api/dto" "time" "github.com/gin-gonic/gin" @@ -32,25 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) { common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting()) } -func (user *UserBase) GetSetting() map[string]interface{} { - if user.Setting == "" { - return nil +func (user *UserBase) GetSetting() dto.UserSetting { + setting := dto.UserSetting{} + if user.Setting != "" { + err := common.Unmarshal([]byte(user.Setting), &setting) + if err != nil { + common.SysError("failed to unmarshal setting: " + err.Error()) + } } - toMap, err := common.StrToMap(user.Setting) - if err != nil { - common.SysError("failed to convert user setting to map: " + err.Error()) - return nil - } - return toMap -} - -func (user *UserBase) SetSetting(setting map[string]interface{}) { - settingBytes, err := json.Marshal(setting) - if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) - return - } - user.Setting = string(settingBytes) + return setting } // getUserCacheKey returns the key for user cache @@ -179,11 +170,10 @@ func getUserNameCache(userId int) (string, error) { return cache.Username, nil } -func getUserSettingCache(userId int) (map[string]interface{}, error) { - setting := make(map[string]interface{}) +func getUserSettingCache(userId int) (dto.UserSetting, error) { cache, err := GetUserCache(userId) if err != nil { - return setting, err + return dto.UserSetting{}, err } return cache.GetSetting(), nil } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index c3da5134..97887266 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return nil, fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 94c53711..375fd531 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -43,7 +43,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Set("Authorization", "Bearer "+info.ApiKey) + keyParts := strings.Split(info.ApiKey, "|") + if len(keyParts) == 0 || keyParts[0] == "" { + return errors.New("invalid API key: authorization token is required") + } + if len(keyParts) > 1 { + if keyParts[1] != "" { + req.Set("appid", keyParts[1]) + } + } + req.Set("Authorization", "Bearer "+keyParts[0]) return nil } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 18604b16..42f8503e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -279,8 +279,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return nil, fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 02178cb8..217790a7 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType // initialize ThinkingContentInfo when thinking_to_content is enabled - if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content { + if info.ChannelSetting.ThinkingToContent { info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -146,7 +146,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * header.Set("Authorization", "Bearer "+info.ApiKey) } if info.ChannelType == constant.ChannelTypeOpenRouter { - header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") + header.Set("HTTP-Referer", "https://www.newapi.ai") header.Set("X-Title", "New API") } return nil diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 5e2a7da9..bfe8bcd3 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -126,12 +126,12 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var forceFormat bool var thinkToContent bool - if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { - forceFormat = forceFmt + if info.ChannelSetting.ForceFormat { + forceFormat = true } - if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok { - thinkToContent = think2Content + if info.ChannelSetting.ThinkingToContent { + thinkToContent = true } var ( @@ -199,8 +199,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } forceFormat := false - if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { - forceFormat = forceFmt + if info.ChannelSetting.ForceFormat { + forceFormat = true } if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 1d41c945..5a97c021 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s var client *http.Client var err error - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return "", fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index ce170df4..beada0ee 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -97,9 +97,9 @@ type RelayInfo struct { IsFirstRequest bool AudioUsage bool ReasoningEffort string - ChannelSetting map[string]interface{} + ChannelSetting dto.ChannelSettings ParamOverride map[string]interface{} - UserSetting map[string]interface{} + UserSetting dto.UserSetting UserEmail string UserQuota int RelayFormat string @@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting) paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) @@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info := &RelayInfo{ UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting), UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), @@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), ParamOverride: paramOverride, RelayFormat: RelayFormatOpenAI, @@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + info.ChannelSetting = channelSetting + } + userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) + if ok { + info.UserSetting = userSetting + } + return info } diff --git a/relay/helper/price.go b/relay/helper/price.go index ab614cbd..9995db2f 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -3,7 +3,6 @@ package helper import ( "fmt" "one-api/common" - constant2 "one-api/constant" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" @@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName) if !success { acceptUnsetRatio := false - if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok { - b, ok := accept.(bool) - if ok { - acceptUnsetRatio = b - } + if info.UserSetting.AcceptUnsetRatioModel { + acceptUnsetRatio = true } if !acceptUnsetRatio { return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index cc09e4a6..f23f8152 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) { } var httpClient *http.Client if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil { - if proxy, ok := channel.GetSetting()["proxy"]; ok { - if proxyURL, ok := proxy.(string); ok && proxyURL != "" { - if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil { - c.JSON(400, gin.H{ - "error": "proxy_url_invalid", - }) - return - } + proxy := channel.GetSetting().Proxy + if proxy != "" { + if httpClient, err = service.NewProxyHttpClient(proxy); err != nil { + c.JSON(400, gin.H{ + "error": "proxy_url_invalid", + }) + return } } } @@ -175,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { startTime := time.Now().UnixNano() / int64(time.Millisecond) tokenId := c.GetInt("token_id") userId := c.GetInt("id") - group := c.GetString("group") + //group := c.GetString("group") channelId := c.GetInt("channel_id") relayInfo := relaycommon.GenRelayInfo(c) var swapFaceRequest dto.SwapFaceRequest @@ -221,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: channelId, + ModelName: modelName, + TokenName: tokenName, + Quota: priceData.Quota, + Content: logContent, + TokenId: tokenId, + UserQuota: userQuota, + Group: relayInfo.UsingGroup, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) model.UpdateChannelUsedQuota(channelId, priceData.Quota) } @@ -363,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - tokenId := c.GetInt("token_id") + //tokenId := c.GetInt("token_id") //channelType := c.GetInt("channel") userId := c.GetInt("id") group := c.GetString("group") @@ -518,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: channelId, + ModelName: modelName, + TokenName: tokenName, + Quota: priceData.Quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + Group: group, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) model.UpdateChannelUsedQuota(channelId, priceData.Quota) } diff --git a/relay/relay-text.go b/relay/relay-text.go index a3917f1c..46120529 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -537,6 +537,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 702cff4c..ce6b93ce 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } - model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, - modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + Group: relayInfo.UsingGroup, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } diff --git a/service/quota.go b/service/quota.go index bc3ef296..7a6177de 100644 --- a/service/quota.go +++ b/service/quota.go @@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.InputTokens, + CompletionTokens: usage.OutputTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, @@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) + } func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { @@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { @@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon gopool.Go(func() { userSetting := relayInfo.UserSetting threshold := common.QuotaRemindThreshold - if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok { - threshold = int(userCustomThreshold.(float64)) + if userSetting.QuotaWarningThreshold != 0 { + threshold = int(userSetting.QuotaWarningThreshold) } //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 diff --git a/service/user_notify.go b/service/user_notify.go index 51f1ff99..96664007 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -3,7 +3,6 @@ package service import ( "fmt" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/model" "strings" @@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) { } } -func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error { - notifyType, ok := userSetting[constant.UserSettingNotifyType] - if !ok { - notifyType = constant.NotifyTypeEmail +func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error { + notifyType := userSetting.NotifyType + if notifyType == "" { + notifyType = dto.NotifyTypeEmail } // Check notification limit @@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{} } switch notifyType { - case constant.NotifyTypeEmail: + case dto.NotifyTypeEmail: // check setting email - if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { - userEmail = settingEmail.(string) - } + userEmail = userSetting.NotificationEmail if userEmail == "" { common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) - case constant.NotifyTypeWebhook: - webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] - if !ok { + case dto.NotifyTypeWebhook: + webhookURLStr := userSetting.WebhookUrl + if webhookURLStr == "" { common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } - webhookURLStr, ok := webhookURL.(string) - if !ok { - common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId)) - return nil - } // 获取 webhook secret - var webhookSecret string - if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok { - webhookSecret, _ = secret.(string) - } - + webhookSecret := userSetting.WebhookSecret return SendWebhookNotify(webhookURLStr, webhookSecret, data) } return nil diff --git a/setting/chat.go b/setting/chat.go index ef308000..53cb655a 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -6,8 +6,11 @@ import ( ) var Chats = []map[string]string{ + //{ + // "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}", + //}, { - "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}", + "Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}", }, { "Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}", diff --git a/web/src/components/table/ChannelsTable.js b/web/src/components/table/ChannelsTable.js index 0e84437d..810993c4 100644 --- a/web/src/components/table/ChannelsTable.js +++ b/web/src/components/table/ChannelsTable.js @@ -1461,9 +1461,9 @@ const ChannelsTable = () => { const fixChannelsAbilities = async () => { const res = await API.post(`/api/channel/fix`); - const { success, message, data } = res.data; + const { success, message, data } = res.data; if (success) { - showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data)); + showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails)); await refresh(); } else { showError(message); diff --git a/web/src/components/table/TokensTable.js b/web/src/components/table/TokensTable.js index f91f7b82..a6d669a6 100644 --- a/web/src/components/table/TokensTable.js +++ b/web/src/components/table/TokensTable.js @@ -432,9 +432,22 @@ const TokensTable = () => { if (serverAddress === '') { serverAddress = window.location.origin; } - let encodedServerAddress = encodeURIComponent(serverAddress); - url = url.replaceAll('{address}', encodedServerAddress); - url = url.replaceAll('{key}', 'sk-' + record.key); + if (url.includes('{cherryConfig}') === true) { + let cherryConfig = { + id: 'new-api', + baseUrl: serverAddress, + apiKey: 'sk-' + record.key, + } + // 替换 {cherryConfig} 为base64编码的JSON字符串 + let encodedConfig = encodeURIComponent( + btoa(JSON.stringify(cherryConfig)) + ); + url = url.replaceAll('{cherryConfig}', encodedConfig); + } else { + let encodedServerAddress = encodeURIComponent(serverAddress); + url = url.replaceAll('{address}', encodedServerAddress); + url = url.replaceAll('{key}', 'sk-' + record.key); + } window.open(url, '_blank'); }; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index de55fb9d..e73f9a2c 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -261,7 +261,7 @@ const EditChannel = (props) => { if (isEdit) { // 如果是编辑模式,使用已有的channel id获取模型列表 const res = await API.get('/api/channel/fetch_models/' + channelId); - if (res.data && res.data?.success) { + if (res.data && res.data.success) { models.push(...res.data.data); } else { err = true;