Merge pull request #1644 from nekohy/feats-custom-request-headers

feats:add custom headers override
This commit is contained in:
Calcium-Ion
2025-08-24 10:14:32 +08:00
committed by GitHub
7 changed files with 76 additions and 9 deletions

View File

@@ -27,6 +27,7 @@ const (
ContextKeyChannelSetting ContextKey = "channel_setting" ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting" ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
ContextKeyChannelParamOverride ContextKey = "param_override" ContextKeyChannelParamOverride ContextKey = "param_override"
ContextKeyChannelHeaderOverride ContextKey = "header_override"
ContextKeyChannelOrganization ContextKey = "channel_organization" ContextKeyChannelOrganization ContextKey = "channel_organization"
ContextKeyChannelAutoBan ContextKey = "auto_ban" ContextKeyChannelAutoBan ContextKey = "auto_ban"
ContextKeyChannelModelMapping ContextKey = "model_mapping" ContextKeyChannelModelMapping ContextKey = "model_mapping"

View File

@@ -248,6 +248,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
} }

View File

@@ -46,6 +46,7 @@ type Channel struct {
Tag *string `json:"tag" gorm:"index"` Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"` ParamOverride *string `json:"param_override" gorm:"type:text"`
HeaderOverride *string `json:"header_override" gorm:"type:text"`
// add after v0.8.5 // add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
@@ -875,6 +876,17 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
return paramOverride return paramOverride
} }
func (channel *Channel) GetHeaderOverride() map[string]interface{} {
headerOverride := make(map[string]interface{})
if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
if err != nil {
common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
}
}
return headerOverride
}
func GetChannelsByIds(ids []int) ([]*Channel, error) { func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error err := DB.Where("id in (?)", ids).Find(&channels).Error

View File

@@ -13,6 +13,7 @@ import (
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting/operation_setting" "one-api/setting/operation_setting"
"one-api/types"
"sync" "sync"
"time" "time"
@@ -47,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
err = a.SetupRequestHeader(c, &req.Header, info) headers := req.Header
headerOverride := make(map[string]string)
for k, v := range info.HeadersOverride {
if str, ok := v.(string); ok {
headerOverride[k] = str
} else {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
}
for key, value := range headerOverride {
headers.Set(key, value)
}
err = a.SetupRequestHeader(c, &headers, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
@@ -72,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
} }
// set form data // set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header
err = a.SetupRequestHeader(c, &req.Header, info) headerOverride := make(map[string]string)
for k, v := range info.HeadersOverride {
if str, ok := v.(string); ok {
headerOverride[k] = str
} else {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
}
for key, value := range headerOverride {
headers.Set(key, value)
}
err = a.SetupRequestHeader(c, &headers, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }

View File

@@ -63,6 +63,7 @@ type ChannelMeta struct {
Organization string Organization string
ChannelCreateTime int64 ChannelCreateTime int64
ParamOverride map[string]interface{} ParamOverride map[string]interface{}
HeadersOverride map[string]interface{}
ChannelSetting dto.ChannelSettings ChannelSetting dto.ChannelSettings
ChannelOtherSettings dto.ChannelOtherSettings ChannelOtherSettings dto.ChannelOtherSettings
UpstreamModelName string UpstreamModelName string
@@ -120,6 +121,7 @@ type RelayInfo struct {
func (info *RelayInfo) InitChannelMeta(c *gin.Context) { func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
apiType, _ := common.ChannelType2APIType(channelType) apiType, _ := common.ChannelType2APIType(channelType)
channelMeta := &ChannelMeta{ channelMeta := &ChannelMeta{
ChannelType: channelType, ChannelType: channelType,
@@ -133,6 +135,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
ChannelCreateTime: c.GetInt64("channel_create_time"), ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride, ParamOverride: paramOverride,
HeadersOverride: headerOverride,
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
IsModelMapped: false, IsModelMapped: false,
SupportStreamOptions: false, SupportStreamOptions: false,

View File

@@ -48,12 +48,13 @@ const (
ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
// channel error // channel error
ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
// client request error // client request error
ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"

View File

@@ -1699,6 +1699,31 @@ const EditChannelModal = (props) => {
showClear showClear
/> />
<Form.TextArea
field='header_override'
label={t('请求头覆盖')}
placeholder={
t('此项可选,用于覆盖请求头参数') +
'\n' + t('格式示例:') +
'\n{\n "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"\n}'
}
autosize
onChange={(value) => handleInputChange('header_override', value)}
extraText={
<div className="flex gap-2 flex-wrap">
<Text
className="!text-semi-color-primary cursor-pointer"
onClick={() => handleInputChange('header_override', JSON.stringify({
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"
}, null, 2))}
>
{t('格式模板')}
</Text>
</div>
}
showClear
/>
<JSONEditor <JSONEditor
key={`status_code_mapping-${isEdit ? channelId : 'new'}`} key={`status_code_mapping-${isEdit ? channelId : 'new'}`}