Merge pull request #1644 from nekohy/feats-custom-request-headers
feats:add custom headers override
This commit is contained in:
@@ -27,6 +27,7 @@ const (
|
|||||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
||||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelHeaderOverride ContextKey = "header_override"
|
||||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
|
|||||||
@@ -248,6 +248,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
|
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
|
||||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ type Channel struct {
|
|||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||||
|
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||||||
// add after v0.8.5
|
// add after v0.8.5
|
||||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||||
|
|
||||||
@@ -875,6 +876,17 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
|
|||||||
return paramOverride
|
return paramOverride
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetHeaderOverride() map[string]interface{} {
|
||||||
|
headerOverride := make(map[string]interface{})
|
||||||
|
if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
|
||||||
|
err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return headerOverride
|
||||||
|
}
|
||||||
|
|
||||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/types"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -47,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
}
|
}
|
||||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
headers := req.Header
|
||||||
|
headerOverride := make(map[string]string)
|
||||||
|
for k, v := range info.HeadersOverride {
|
||||||
|
if str, ok := v.(string); ok {
|
||||||
|
headerOverride[k] = str
|
||||||
|
} else {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, value := range headerOverride {
|
||||||
|
headers.Set(key, value)
|
||||||
|
}
|
||||||
|
err = a.SetupRequestHeader(c, &headers, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -72,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|||||||
}
|
}
|
||||||
// set form data
|
// set form data
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
|
headers := req.Header
|
||||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
headerOverride := make(map[string]string)
|
||||||
|
for k, v := range info.HeadersOverride {
|
||||||
|
if str, ok := v.(string); ok {
|
||||||
|
headerOverride[k] = str
|
||||||
|
} else {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, value := range headerOverride {
|
||||||
|
headers.Set(key, value)
|
||||||
|
}
|
||||||
|
err = a.SetupRequestHeader(c, &headers, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ type ChannelMeta struct {
|
|||||||
Organization string
|
Organization string
|
||||||
ChannelCreateTime int64
|
ChannelCreateTime int64
|
||||||
ParamOverride map[string]interface{}
|
ParamOverride map[string]interface{}
|
||||||
|
HeadersOverride map[string]interface{}
|
||||||
ChannelSetting dto.ChannelSettings
|
ChannelSetting dto.ChannelSettings
|
||||||
ChannelOtherSettings dto.ChannelOtherSettings
|
ChannelOtherSettings dto.ChannelOtherSettings
|
||||||
UpstreamModelName string
|
UpstreamModelName string
|
||||||
@@ -120,6 +121,7 @@ type RelayInfo struct {
|
|||||||
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||||
|
headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
|
||||||
apiType, _ := common.ChannelType2APIType(channelType)
|
apiType, _ := common.ChannelType2APIType(channelType)
|
||||||
channelMeta := &ChannelMeta{
|
channelMeta := &ChannelMeta{
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
@@ -133,6 +135,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
|||||||
Organization: c.GetString("channel_organization"),
|
Organization: c.GetString("channel_organization"),
|
||||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||||
ParamOverride: paramOverride,
|
ParamOverride: paramOverride,
|
||||||
|
HeadersOverride: headerOverride,
|
||||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||||
IsModelMapped: false,
|
IsModelMapped: false,
|
||||||
SupportStreamOptions: false,
|
SupportStreamOptions: false,
|
||||||
|
|||||||
@@ -48,12 +48,13 @@ const (
|
|||||||
ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
|
ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
|
||||||
|
|
||||||
// channel error
|
// channel error
|
||||||
ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
|
ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
|
||||||
ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
|
ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
|
||||||
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
|
ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
|
||||||
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
|
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
|
||||||
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
|
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
|
||||||
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
|
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
|
||||||
|
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
|
||||||
|
|
||||||
// client request error
|
// client request error
|
||||||
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
|
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
|
||||||
|
|||||||
@@ -1699,6 +1699,31 @@ const EditChannelModal = (props) => {
|
|||||||
showClear
|
showClear
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<Form.TextArea
|
||||||
|
field='header_override'
|
||||||
|
label={t('请求头覆盖')}
|
||||||
|
placeholder={
|
||||||
|
t('此项可选,用于覆盖请求头参数') +
|
||||||
|
'\n' + t('格式示例:') +
|
||||||
|
'\n{\n "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"\n}'
|
||||||
|
}
|
||||||
|
autosize
|
||||||
|
onChange={(value) => handleInputChange('header_override', value)}
|
||||||
|
extraText={
|
||||||
|
<div className="flex gap-2 flex-wrap">
|
||||||
|
<Text
|
||||||
|
className="!text-semi-color-primary cursor-pointer"
|
||||||
|
onClick={() => handleInputChange('header_override', JSON.stringify({
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"
|
||||||
|
}, null, 2))}
|
||||||
|
>
|
||||||
|
{t('格式模板')}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
showClear
|
||||||
|
/>
|
||||||
|
|
||||||
|
|
||||||
<JSONEditor
|
<JSONEditor
|
||||||
key={`status_code_mapping-${isEdit ? channelId : 'new'}`}
|
key={`status_code_mapping-${isEdit ? channelId : 'new'}`}
|
||||||
|
|||||||
Reference in New Issue
Block a user