diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 79a0f706..548e720d 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -14,6 +14,7 @@ import ( "one-api/service" "one-api/setting/operation_setting" "one-api/types" + "strings" "sync" "time" @@ -36,6 +37,26 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea } } +// processHeaderOverride 处理请求头覆盖,支持变量替换 +// 支持的变量:{api_key} +func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) { + headerOverride := make(map[string]string) + for k, v := range info.HeadersOverride { + str, ok := v.(string) + if !ok { + return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid) + } + + // 替换支持的变量 + if strings.Contains(str, "{api_key}") { + str = strings.ReplaceAll(str, "{api_key}", info.ApiKey) + } + + headerOverride[k] = str + } + return headerOverride, nil +} + func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(info) if err != nil { @@ -49,13 +70,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return nil, fmt.Errorf("new request failed: %w", err) } headers := req.Header - headerOverride := make(map[string]string) - for k, v := range info.HeadersOverride { - if str, ok := v.(string); ok { - headerOverride[k] = str - } else { - return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) - } + headerOverride, err := processHeaderOverride(info) + if err != nil { + return nil, err } for key, value := range headerOverride { headers.Set(key, value) @@ -86,13 +103,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod // set form data req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) headers := req.Header - headerOverride := make(map[string]string) - for k, v := range info.HeadersOverride { - if str, ok := v.(string); ok { - headerOverride[k] = str - } else { - return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) - } + headerOverride, err := processHeaderOverride(info) + if err != nil { + return nil, err } for key, value := range headerOverride { headers.Set(key, value) @@ -114,6 +127,13 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return nil, fmt.Errorf("get request url failed: %w", err) } targetHeader := http.Header{} + headerOverride, err := processHeaderOverride(info) + if err != nil { + return nil, err + } + for key, value := range headerOverride { + targetHeader.Set(key, value) + } err = a.SetupRequestHeader(c, &targetHeader, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index e9a21c20..dfbd75a4 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -2452,32 +2452,44 @@ const EditChannelModal = (props) => { t('此项可选,用于覆盖请求头参数') + '\n' + t('格式示例:') + - '\n{\n "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"\n}' + '\n{\n "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0",\n "Authorization": "Bearer {api_key}"\n}' } autosize onChange={(value) => handleInputChange('header_override', value) } extraText={ -