diff --git a/middleware/distributor.go b/middleware/distributor.go index 49fcf59b..fc9f5512 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -212,6 +212,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) c.Set("channel_setting", channel.GetSetting()) + c.Set("param_override", channel.GetParamOverride()) if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } diff --git a/model/channel.go b/model/channel.go index d51a345e..91f5384c 100644 --- a/model/channel.go +++ b/model/channel.go @@ -36,6 +36,7 @@ type Channel struct { OtherInfo string `json:"other_info"` Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` + ParamOverride *string `json:"param_override" gorm:"type:text"` } func (channel *Channel) GetModels() []string { @@ -511,6 +512,17 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) { channel.Setting = common.GetPointer[string](string(settingBytes)) } +func (channel *Channel) GetParamOverride() map[string]interface{} { + paramOverride := make(map[string]interface{}) + if channel.ParamOverride != nil && *channel.ParamOverride != "" { + err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) + if err != nil { + common.SysError("failed to unmarshal param override: " + err.Error()) + } + } + return paramOverride +} + func GetChannelsByIds(ids []int) ([]*Channel, error) { var channels []*Channel err := DB.Where("id in (?)", ids).Find(&channels).Error diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 059c8284..0a7678ea 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -76,6 +76,7 @@ type RelayInfo struct { AudioUsage bool ReasoningEffort string ChannelSetting map[string]interface{} + ParamOverride map[string]interface{} UserSetting map[string]interface{} UserEmail string UserQuota int @@ -131,6 +132,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") channelSetting := c.GetStringMap("channel_setting") + paramOverride := c.GetStringMap("param_override") tokenId := c.GetInt("token_id") tokenKey := c.GetString("token_key") @@ -168,6 +170,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), ChannelSetting: channelSetting, + ParamOverride: paramOverride, RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, diff --git a/relay/relay-text.go b/relay/relay-text.go index c871b80b..7b2b7fc0 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -168,6 +168,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if err != nil { return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } + + // apply param override + if len(relayInfo.ParamOverride) > 0 { + reqMap := make(map[string]interface{}) + err = json.Unmarshal(jsonData, &reqMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError) + } + for key, value := range relayInfo.ParamOverride { + reqMap[key] = value + } + jsonData, err = json.Marshal(reqMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError) + } + } + if common.DebugEnabled { println("requestBody: ", string(jsonData)) } diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 6e3ddf32..cf820b81 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1275,6 +1275,7 @@ "代理站地址": "Base URL", "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in", "渠道额外设置": "Channel extra settings", + "参数覆盖": "Parameters override", "模型请求速率限制": "Model request rate limit", "启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)", "限制周期": "Limit period", diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index f3024d9b..a127c09a 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -983,6 +983,23 @@ const EditChannel = (props) => { > + <> +