From abcb35379389ba15d25d3d13512faea6aca227e7 Mon Sep 17 00:00:00 2001 From: Nekohy Date: Sun, 24 Aug 2025 01:02:23 +0800 Subject: [PATCH 1/2] feats:add custom headers override --- constant/context_key.go | 1 + middleware/distributor.go | 1 + model/channel.go | 12 ++++++++ relay/channel/api_request.go | 30 +++++++++++++++++-- relay/common/relay_info.go | 3 ++ types/error.go | 13 ++++---- .../channels/modals/EditChannelModal.jsx | 25 ++++++++++++++++ 7 files changed, 76 insertions(+), 9 deletions(-) diff --git a/constant/context_key.go b/constant/context_key.go index 3945243c..f7640272 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -27,6 +27,7 @@ const ( ContextKeyChannelSetting ContextKey = "channel_setting" ContextKeyChannelOtherSetting ContextKey = "channel_other_setting" ContextKeyChannelParamOverride ContextKey = "param_override" + ContextKeyChannelHeaderOverride ContextKey = "header_override" ContextKeyChannelOrganization ContextKey = "channel_organization" ContextKeyChannelAutoBan ContextKey = "auto_ban" ContextKeyChannelModelMapping ContextKey = "model_mapping" diff --git a/middleware/distributor.go b/middleware/distributor.go index 28b66a3a..a80ed3c6 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -248,6 +248,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) + common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride()) if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) } diff --git a/model/channel.go b/model/channel.go index a9a23481..77113edf 100644 --- a/model/channel.go +++ b/model/channel.go @@ -46,6 +46,7 @@ type Channel struct { Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` + HeaderOverride *string `json:"header_override" gorm:"type:text"` // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` @@ -875,6 +876,17 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { return paramOverride } +func (channel *Channel) GetHeaderOverride() map[string]interface{} { + headerOverride := make(map[string]interface{}) + if channel.HeaderOverride != nil && *channel.HeaderOverride != "" { + err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride) + if err != nil { + common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) + } + } + return headerOverride +} + func GetChannelsByIds(ids []int) ([]*Channel, error) { var channels []*Channel err := DB.Where("id in (?)", ids).Find(&channels).Error diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index fd745cf7..518d25ce 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -13,6 +13,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting/operation_setting" + "one-api/types" "sync" "time" @@ -47,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } - err = a.SetupRequestHeader(c, &req.Header, info) + headers := req.Header + headerOverride := make(map[string]string) + for k, v := range info.HeadersOverride { + if str, ok := v.(string); ok { + headerOverride[k] = str + } else { + return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) + } + } + for key, value := range headerOverride { + headers.Set(key, value) + } + err = a.SetupRequestHeader(c, &headers, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } @@ -72,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod } // set form data req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - - err = a.SetupRequestHeader(c, &req.Header, info) + headers := req.Header + headerOverride := make(map[string]string) + for k, v := range info.HeadersOverride { + if str, ok := v.(string); ok { + headerOverride[k] = str + } else { + return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) + } + } + for key, value := range headerOverride { + headers.Set(key, value) + } + err = a.SetupRequestHeader(c, &headers, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 032a577d..caf8b452 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -63,6 +63,7 @@ type ChannelMeta struct { Organization string ChannelCreateTime int64 ParamOverride map[string]interface{} + HeadersOverride map[string]interface{} ChannelSetting dto.ChannelSettings ChannelOtherSettings dto.ChannelOtherSettings UpstreamModelName string @@ -120,6 +121,7 @@ type RelayInfo struct { func (info *RelayInfo) InitChannelMeta(c *gin.Context) { channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride) apiType, _ := common.ChannelType2APIType(channelType) channelMeta := &ChannelMeta{ ChannelType: channelType, @@ -133,6 +135,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { Organization: c.GetString("channel_organization"), ChannelCreateTime: c.GetInt64("channel_create_time"), ParamOverride: paramOverride, + HeadersOverride: headerOverride, UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), IsModelMapped: false, SupportStreamOptions: false, diff --git a/types/error.go b/types/error.go index 07486c27..6b0dd84c 100644 --- a/types/error.go +++ b/types/error.go @@ -48,12 +48,13 @@ const ( ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" // channel error - ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" - ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" - ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" - ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" - ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" - ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" + ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" + ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" + ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid" + ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" + ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" + ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" + ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" // client request error ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 23fe2c58..b4e5a403 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -1699,6 +1699,31 @@ const EditChannelModal = (props) => { showClear /> + handleInputChange('header_override', value)} + extraText={ +
+ handleInputChange('header_override', JSON.stringify({ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0" + }, null, 2))} + > + {t('格式模板')} + +
+ } + showClear + /> + Date: Sun, 24 Aug 2025 01:32:19 +0800 Subject: [PATCH 2/2] fix: log name --- model/channel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/channel.go b/model/channel.go index 77113edf..0d21b076 100644 --- a/model/channel.go +++ b/model/channel.go @@ -881,7 +881,7 @@ func (channel *Channel) GetHeaderOverride() map[string]interface{} { if channel.HeaderOverride != nil && *channel.HeaderOverride != "" { err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride) if err != nil { - common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) + common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err)) } } return headerOverride