diff --git a/common/gin.go b/common/gin.go index 62c4c692..d428184a 100644 --- a/common/gin.go +++ b/common/gin.go @@ -76,3 +76,13 @@ func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string] func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { return c.GetTime(string(key)) } + +func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) { + if value, ok := c.Get(string(key)); ok { + if v, ok := value.(T); ok { + return v, true + } + } + var t T + return t, false +} diff --git a/common/str.go b/common/str.go index 9eb19fe7..76dd801a 100644 --- a/common/str.go +++ b/common/str.go @@ -73,3 +73,11 @@ func StringToByteSlice(s string) []byte { func EncodeBase64(str string) string { return base64.StdEncoding.EncodeToString([]byte(str)) } + +func GetJsonString(data any) string { + if data == nil { + return "" + } + b, _ := json.Marshal(data) + return string(b) +} diff --git a/constant/channel_setting.go b/constant/channel_setting.go deleted file mode 100644 index e06e7eb1..00000000 --- a/constant/channel_setting.go +++ /dev/null @@ -1,7 +0,0 @@ -package constant - -var ( - ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 - ChanelSettingProxy = "proxy" // Proxy 代理 - ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent -) diff --git a/constant/user_setting.go b/constant/user_setting.go deleted file mode 100644 index 7e79035e..00000000 --- a/constant/user_setting.go +++ /dev/null @@ -1,16 +0,0 @@ -package constant - -var ( - UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型 - UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值 - UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 - UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 - UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 - UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型 - UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP -) - -var ( - NotifyTypeEmail = "email" // Email 邮件 - NotifyTypeWebhook = "webhook" // Webhook -) diff --git a/controller/channel-test.go b/controller/channel-test.go index 0b474c25..009219a7 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -173,8 +173,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr consumedTime := float64(milliseconds) / 1000.0 other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other) + model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ + ChannelId: channel.Id, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: info.OriginModelName, + TokenName: "模型测试", + Quota: quota, + Content: "模型测试", + UseTimeSeconds: int(consumedTime), + IsStream: false, + Group: info.UsingGroup, + Other: other, + }) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } diff --git a/controller/channel.go b/controller/channel.go index 98ef3c08..4c7d28f2 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -387,6 +387,14 @@ func AddChannel(c *gin.Context) { }) return } + err = channel.ValidateSettings() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel setting 格式错误:" + err.Error(), + }) + return + } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") if channel.Type == constant.ChannelTypeVertexAi { @@ -614,6 +622,14 @@ func UpdateChannel(c *gin.Context) { }) return } + err = channel.ValidateSettings() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel setting 格式错误:" + err.Error(), + }) + return + } if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/user.go b/controller/user.go index ca161f42..44450836 100644 --- a/controller/user.go +++ b/controller/user.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/dto" "one-api/model" "one-api/setting" "strconv" @@ -961,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) { } // 验证预警类型 - if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { + if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的预警类型", @@ -979,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) { } // 如果是webhook类型,验证webhook地址 - if req.QuotaWarningType == constant.NotifyTypeWebhook { + if req.QuotaWarningType == dto.NotifyTypeWebhook { if req.WebhookUrl == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -998,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) { } // 如果是邮件类型,验证邮箱地址 - if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" { // 验证邮箱格式 if !strings.Contains(req.NotificationEmail, "@") { c.JSON(http.StatusOK, gin.H{ @@ -1020,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) { } // 构建设置 - settings := map[string]interface{}{ - constant.UserSettingNotifyType: req.QuotaWarningType, - constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, - "accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel, - constant.UserSettingRecordIpLog: req.RecordIpLog, + settings := dto.UserSetting{ + NotifyType: req.QuotaWarningType, + QuotaWarningThreshold: req.QuotaWarningThreshold, + AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel, + RecordIpLog: req.RecordIpLog, } // 如果是webhook类型,添加webhook相关设置 - if req.QuotaWarningType == constant.NotifyTypeWebhook { - settings[constant.UserSettingWebhookUrl] = req.WebhookUrl + if req.QuotaWarningType == dto.NotifyTypeWebhook { + settings.WebhookUrl = req.WebhookUrl if req.WebhookSecret != "" { - settings[constant.UserSettingWebhookSecret] = req.WebhookSecret + settings.WebhookSecret = req.WebhookSecret } } // 如果提供了通知邮箱,添加到设置中 - if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { - settings[constant.UserSettingNotificationEmail] = req.NotificationEmail + if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" { + settings.NotificationEmail = req.NotificationEmail } // 更新用户设置 diff --git a/dto/channel_settings.go b/dto/channel_settings.go new file mode 100644 index 00000000..871d6716 --- /dev/null +++ b/dto/channel_settings.go @@ -0,0 +1,7 @@ +package dto + +type ChannelSettings struct { + ForceFormat bool `json:"force_format,omitempty"` + ThinkingToContent bool `json:"thinking_to_content,omitempty"` + Proxy string `json:"proxy"` +} diff --git a/dto/user_settings.go b/dto/user_settings.go new file mode 100644 index 00000000..2e1a1541 --- /dev/null +++ b/dto/user_settings.go @@ -0,0 +1,16 @@ +package dto + +type UserSetting struct { + NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型 + QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值 + WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 + WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 + NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 + AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 + RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP +} + +var ( + NotifyTypeEmail = "email" // Email 邮件 + NotifyTypeWebhook = "webhook" // Webhook +) diff --git a/main.go b/main.go index 727d5db6..b89350b0 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,6 @@ func main() { return } - common.SetupLogger() common.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) @@ -172,6 +171,8 @@ func InitResources() error { // 加载环境变量 common.InitEnv() + common.SetupLogger() + // Initialize model settings ratio_setting.InitRatioSettings() diff --git a/middleware/distributor.go b/middleware/distributor.go index 17916e7a..642b5253 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -247,9 +247,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) - c.Set("channel_type", channel.Type) + common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) c.Set("channel_create_time", channel.CreatedTime) - c.Set("channel_setting", channel.GetSetting()) + common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) c.Set("param_override", channel.GetParamOverride()) if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) @@ -258,7 +258,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("model_mapping", channel.GetModelMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL()) // TODO: api_version统一 switch channel.Type { case constant.ChannelTypeAzure: diff --git a/model/channel.go b/model/channel.go index 6cbd8adc..a5f307ef 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,6 +3,7 @@ package model import ( "encoding/json" "one-api/common" + "one-api/dto" "strings" "sync" @@ -514,8 +515,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str return tags, nil } -func (channel *Channel) GetSetting() map[string]interface{} { - setting := make(map[string]interface{}) +func (channel *Channel) ValidateSettings() error { + channelParams := &dto.ChannelSettings{} + if channel.Setting != nil && *channel.Setting != "" { + err := json.Unmarshal([]byte(*channel.Setting), channelParams) + if err != nil { + return err + } + } + return nil +} + +func (channel *Channel) GetSetting() dto.ChannelSettings { + setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { err := json.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { @@ -525,7 +537,7 @@ func (channel *Channel) GetSetting() map[string]interface{} { return setting } -func (channel *Channel) SetSetting(setting map[string]interface{}) { +func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) diff --git a/model/log.go b/model/log.go index b3fd1ad2..e2d1ee5a 100644 --- a/model/log.go +++ b/model/log.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "one-api/common" - "one-api/constant" "os" "strings" "time" @@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, // 判断是否需要记录 IP needRecordIp := false if settingMap, err := GetUserSetting(userId, false); err == nil { - if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok { - if vb, ok := v.(bool); ok && vb { - needRecordIp = true - } + if settingMap.RecordIpLog { + needRecordIp = true } } log := &Log{ @@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } } -func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int, - modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, - isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) +type RecordConsumeLogParams struct { + ChannelId int `json:"channel_id"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + ModelName string `json:"model_name"` + TokenName string `json:"token_name"` + Quota int `json:"quota"` + Content string `json:"content"` + TokenId int `json:"token_id"` + UserQuota int `json:"user_quota"` + UseTimeSeconds int `json:"use_time_seconds"` + IsStream bool `json:"is_stream"` + Group string `json:"group"` + Other map[string]interface{} `json:"other"` +} + +func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { + common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } username := c.GetString("username") - otherStr := common.MapToJsonStr(other) + otherStr := common.MapToJsonStr(params.Other) // 判断是否需要记录 IP needRecordIp := false if settingMap, err := GetUserSetting(userId, false); err == nil { - if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok { - if vb, ok := v.(bool); ok && vb { - needRecordIp = true - } + if settingMap.RecordIpLog { + needRecordIp = true } } log := &Log{ @@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in Username: username, CreatedAt: common.GetTimestamp(), Type: LogTypeConsume, - Content: content, - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TokenName: tokenName, - ModelName: modelName, - Quota: quota, - ChannelId: channelId, - TokenId: tokenId, - UseTime: useTimeSeconds, - IsStream: isStream, - Group: group, + Content: params.Content, + PromptTokens: params.PromptTokens, + CompletionTokens: params.CompletionTokens, + TokenName: params.TokenName, + ModelName: params.ModelName, + Quota: params.Quota, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + UseTime: params.UseTimeSeconds, + IsStream: params.IsStream, + Group: params.Group, Ip: func() string { if needRecordIp { return c.ClientIP() @@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in } if common.DataExportEnabled { gopool.Go(func() { - LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) + LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens) }) } } diff --git a/model/user.go b/model/user.go index bd685e54..6bb5a867 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/dto" "strconv" "strings" @@ -68,14 +69,18 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } -func (user *User) GetSetting() map[string]interface{} { - if user.Setting == "" { - return nil +func (user *User) GetSetting() dto.UserSetting { + setting := dto.UserSetting{} + if user.Setting != "" { + err := json.Unmarshal([]byte(user.Setting), &setting) + if err != nil { + common.SysError("failed to unmarshal setting: " + err.Error()) + } } - return common.StrToMap(user.Setting) + return setting } -func (user *User) SetSetting(setting map[string]interface{}) { +func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) @@ -626,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { } // GetUserSetting gets setting from Redis first, falls back to DB if needed -func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) { +func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) { var setting string defer func() { // Update Redis cache asynchronously on successful DB read @@ -648,10 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error if err != nil { - return map[string]interface{}{}, err + return settingMap, err } - - return common.StrToMap(setting), nil + userBase := &UserBase{ + Setting: setting, + } + return userBase.GetSetting(), nil } func IncreaseUserQuota(id int, quota int, db bool) (err error) { diff --git a/model/user_cache.go b/model/user_cache.go index b4bc2f1e..a62d9773 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/constant" + "one-api/dto" "time" "github.com/gin-gonic/gin" @@ -32,20 +33,15 @@ func (user *UserBase) WriteContext(c *gin.Context) { common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting()) } -func (user *UserBase) GetSetting() map[string]interface{} { - if user.Setting == "" { - return nil +func (user *UserBase) GetSetting() dto.UserSetting { + setting := dto.UserSetting{} + if user.Setting != "" { + err := json.Unmarshal([]byte(user.Setting), &setting) + if err != nil { + common.SysError("failed to unmarshal setting: " + err.Error()) + } } - return common.StrToMap(user.Setting) -} - -func (user *UserBase) SetSetting(setting map[string]interface{}) { - settingBytes, err := json.Marshal(setting) - if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) - return - } - user.Setting = string(settingBytes) + return setting } // getUserCacheKey returns the key for user cache @@ -174,11 +170,10 @@ func getUserNameCache(userId int) (string, error) { return cache.Username, nil } -func getUserSettingCache(userId int) (map[string]interface{}, error) { - setting := make(map[string]interface{}) +func getUserSettingCache(userId int) (dto.UserSetting, error) { cache, err := GetUserCache(userId) if err != nil { - return setting, err + return dto.UserSetting{}, err } return cache.GetSetting(), nil } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index c3da5134..97887266 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return nil, fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 6c08261b..618fe16f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -278,8 +278,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return nil, fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index c0a3d7f8..367dbc47 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType // initialize ThinkingContentInfo when thinking_to_content is enabled - if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content { + if info.ChannelSetting.ThinkingToContent { info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 7c283bd0..6aa73274 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -124,12 +124,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var forceFormat bool var thinkToContent bool - if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { - forceFormat = forceFmt + if info.ChannelSetting.ForceFormat { + forceFormat = true } - if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok { - thinkToContent = think2Content + if info.ChannelSetting.ThinkingToContent { + thinkToContent = true } var ( @@ -200,8 +200,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI } forceFormat := false - if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { - forceFormat = forceFmt + if info.ChannelSetting.ForceFormat { + forceFormat = true } if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 1d41c945..5a97c021 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -106,8 +106,8 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s var client *http.Client var err error - if proxyURL, ok := info.ChannelSetting["proxy"]; ok { - client, err = service.NewProxyHttpClient(proxyURL.(string)) + if info.ChannelSetting.Proxy != "" { + client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) if err != nil { return "", fmt.Errorf("new proxy http client failed: %w", err) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 37161c16..2f5f5d38 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -97,9 +97,9 @@ type RelayInfo struct { IsFirstRequest bool AudioUsage bool ReasoningEffort string - ChannelSetting map[string]interface{} + ChannelSetting dto.ChannelSettings ParamOverride map[string]interface{} - UserSetting map[string]interface{} + UserSetting dto.UserSetting UserEmail string UserQuota int RelayFormat string @@ -213,7 +213,6 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting) paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride) tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) @@ -227,7 +226,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info := &RelayInfo{ UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting), UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), @@ -246,12 +244,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), ParamOverride: paramOverride, RelayFormat: RelayFormatOpenAI, @@ -277,6 +275,16 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + info.ChannelSetting = channelSetting + } + userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) + if ok { + info.UserSetting = userSetting + } + return info } diff --git a/relay/helper/price.go b/relay/helper/price.go index ab614cbd..9995db2f 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -3,7 +3,6 @@ package helper import ( "fmt" "one-api/common" - constant2 "one-api/constant" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" @@ -83,11 +82,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName) if !success { acceptUnsetRatio := false - if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok { - b, ok := accept.(bool) - if ok { - acceptUnsetRatio = b - } + if info.UserSetting.AcceptUnsetRatioModel { + acceptUnsetRatio = true } if !acceptUnsetRatio { return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index cc09e4a6..f23f8152 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) { } var httpClient *http.Client if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil { - if proxy, ok := channel.GetSetting()["proxy"]; ok { - if proxyURL, ok := proxy.(string); ok && proxyURL != "" { - if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil { - c.JSON(400, gin.H{ - "error": "proxy_url_invalid", - }) - return - } + proxy := channel.GetSetting().Proxy + if proxy != "" { + if httpClient, err = service.NewProxyHttpClient(proxy); err != nil { + c.JSON(400, gin.H{ + "error": "proxy_url_invalid", + }) + return } } } @@ -175,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { startTime := time.Now().UnixNano() / int64(time.Millisecond) tokenId := c.GetInt("token_id") userId := c.GetInt("id") - group := c.GetString("group") + //group := c.GetString("group") channelId := c.GetInt("channel_id") relayInfo := relaycommon.GenRelayInfo(c) var swapFaceRequest dto.SwapFaceRequest @@ -221,8 +220,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: channelId, + ModelName: modelName, + TokenName: tokenName, + Quota: priceData.Quota, + Content: logContent, + TokenId: tokenId, + UserQuota: userQuota, + Group: relayInfo.UsingGroup, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) model.UpdateChannelUsedQuota(channelId, priceData.Quota) } @@ -363,7 +371,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - tokenId := c.GetInt("token_id") + //tokenId := c.GetInt("token_id") //channelType := c.GetInt("channel") userId := c.GetInt("id") group := c.GetString("group") @@ -518,8 +526,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: channelId, + ModelName: modelName, + TokenName: tokenName, + Quota: priceData.Quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + Group: group, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) model.UpdateChannelUsedQuota(channelId, priceData.Quota) } diff --git a/relay/relay-text.go b/relay/relay-text.go index e0c8f047..86b6c530 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -540,6 +540,19 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } diff --git a/relay/relay_task.go b/relay/relay_task.go index 702cff4c..ce6b93ce 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -139,8 +139,17 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } - model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, - modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other) + model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + Group: relayInfo.UsingGroup, + Other: other, + }) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } diff --git a/service/quota.go b/service/quota.go index bc3ef296..7a6177de 100644 --- a/service/quota.go +++ b/service/quota.go @@ -209,8 +209,21 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.InputTokens, + CompletionTokens: usage.OutputTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, @@ -286,8 +299,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + ModelName: modelName, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) + } func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { @@ -384,8 +411,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ + ChannelId: relayInfo.ChannelId, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + ModelName: logModel, + TokenName: tokenName, + Quota: quota, + Content: logContent, + TokenId: relayInfo.TokenId, + UserQuota: userQuota, + UseTimeSeconds: int(useTimeSeconds), + IsStream: relayInfo.IsStream, + Group: relayInfo.UsingGroup, + Other: other, + }) } func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { @@ -447,8 +487,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon gopool.Go(func() { userSetting := relayInfo.UserSetting threshold := common.QuotaRemindThreshold - if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok { - threshold = int(userCustomThreshold.(float64)) + if userSetting.QuotaWarningThreshold != 0 { + threshold = int(userSetting.QuotaWarningThreshold) } //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 diff --git a/service/user_notify.go b/service/user_notify.go index 51f1ff99..96664007 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -3,7 +3,6 @@ package service import ( "fmt" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/model" "strings" @@ -17,10 +16,10 @@ func NotifyRootUser(t string, subject string, content string) { } } -func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error { - notifyType, ok := userSetting[constant.UserSettingNotifyType] - if !ok { - notifyType = constant.NotifyTypeEmail +func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error { + notifyType := userSetting.NotifyType + if notifyType == "" { + notifyType = dto.NotifyTypeEmail } // Check notification limit @@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{} } switch notifyType { - case constant.NotifyTypeEmail: + case dto.NotifyTypeEmail: // check setting email - if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { - userEmail = settingEmail.(string) - } + userEmail = userSetting.NotificationEmail if userEmail == "" { common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) - case constant.NotifyTypeWebhook: - webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] - if !ok { + case dto.NotifyTypeWebhook: + webhookURLStr := userSetting.WebhookUrl + if webhookURLStr == "" { common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } - webhookURLStr, ok := webhookURL.(string) - if !ok { - common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId)) - return nil - } // 获取 webhook secret - var webhookSecret string - if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok { - webhookSecret, _ = secret.(string) - } - + webhookSecret := userSetting.WebhookSecret return SendWebhookNotify(webhookURLStr, webhookSecret, data) } return nil