feat: enhance Authorization header handling with Header Override support
This commit is contained in:
@@ -71,6 +71,12 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
}
|
}
|
||||||
headers := req.Header
|
headers := req.Header
|
||||||
|
err = a.SetupRequestHeader(c, &headers, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
|
}
|
||||||
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
||||||
|
// 这样可以覆盖默认的 Authorization header 设置
|
||||||
headerOverride, err := processHeaderOverride(info)
|
headerOverride, err := processHeaderOverride(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -78,10 +84,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
for key, value := range headerOverride {
|
for key, value := range headerOverride {
|
||||||
headers.Set(key, value)
|
headers.Set(key, value)
|
||||||
}
|
}
|
||||||
err = a.SetupRequestHeader(c, &headers, info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
|
||||||
}
|
|
||||||
resp, err := doRequest(c, req, info)
|
resp, err := doRequest(c, req, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("do request failed: %w", err)
|
return nil, fmt.Errorf("do request failed: %w", err)
|
||||||
@@ -104,6 +106,12 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|||||||
// set form data
|
// set form data
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
headers := req.Header
|
headers := req.Header
|
||||||
|
err = a.SetupRequestHeader(c, &headers, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
|
}
|
||||||
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
||||||
|
// 这样可以覆盖默认的 Authorization header 设置
|
||||||
headerOverride, err := processHeaderOverride(info)
|
headerOverride, err := processHeaderOverride(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -111,10 +119,6 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|||||||
for key, value := range headerOverride {
|
for key, value := range headerOverride {
|
||||||
headers.Set(key, value)
|
headers.Set(key, value)
|
||||||
}
|
}
|
||||||
err = a.SetupRequestHeader(c, &headers, info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
|
||||||
}
|
|
||||||
resp, err := doRequest(c, req, info)
|
resp, err := doRequest(c, req, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("do request failed: %w", err)
|
return nil, fmt.Errorf("do request failed: %w", err)
|
||||||
@@ -128,6 +132,12 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||||
}
|
}
|
||||||
targetHeader := http.Header{}
|
targetHeader := http.Header{}
|
||||||
|
err = a.SetupRequestHeader(c, &targetHeader, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||||
|
}
|
||||||
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
||||||
|
// 这样可以覆盖默认的 Authorization header 设置
|
||||||
headerOverride, err := processHeaderOverride(info)
|
headerOverride, err := processHeaderOverride(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -135,10 +145,6 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
for key, value := range headerOverride {
|
for key, value := range headerOverride {
|
||||||
targetHeader.Set(key, value)
|
targetHeader.Set(key, value)
|
||||||
}
|
}
|
||||||
err = a.SetupRequestHeader(c, &targetHeader, info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
|
||||||
}
|
|
||||||
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
|
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -184,9 +184,25 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
header.Set("api-key", info.ApiKey)
|
header.Set("api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// 自定义渠道类型完全跳过默认 Authorization 设置,由 Header Override 控制
|
||||||
|
if info.ChannelType == constant.ChannelTypeCustom {
|
||||||
|
// 自定义渠道不设置默认 Authorization,完全由 Header Override 控制
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
||||||
header.Set("OpenAI-Organization", info.Organization)
|
header.Set("OpenAI-Organization", info.Organization)
|
||||||
}
|
}
|
||||||
|
// 检查 Header Override 是否已设置 Authorization,如果已设置则跳过默认设置
|
||||||
|
// 这样可以避免在 Header Override 应用时被覆盖(虽然 Header Override 会在之后应用,但这里作为额外保护)
|
||||||
|
hasAuthOverride := false
|
||||||
|
if len(info.HeadersOverride) > 0 {
|
||||||
|
for k := range info.HeadersOverride {
|
||||||
|
if strings.EqualFold(k, "Authorization") {
|
||||||
|
hasAuthOverride = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||||
if swp != "" {
|
if swp != "" {
|
||||||
@@ -201,10 +217,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
|
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
|
||||||
} else {
|
} else {
|
||||||
header.Set("openai-beta", "realtime=v1")
|
header.Set("openai-beta", "realtime=v1")
|
||||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
if !hasAuthOverride {
|
||||||
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
if !hasAuthOverride {
|
||||||
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||||
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
header.Set("HTTP-Referer", "https://www.newapi.ai")
|
||||||
|
|||||||
Reference in New Issue
Block a user