fix: ignore header passthrough during channel tests
This commit is contained in:
@@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
|
|
||||||
passAll := false
|
passAll := false
|
||||||
var passthroughRegex []*regexp.Regexp
|
var passthroughRegex []*regexp.Regexp
|
||||||
for k := range info.HeadersOverride {
|
if !info.IsChannelTest {
|
||||||
key := strings.TrimSpace(k)
|
for k := range info.HeadersOverride {
|
||||||
if key == "" {
|
key := strings.TrimSpace(k)
|
||||||
continue
|
if key == "" {
|
||||||
}
|
continue
|
||||||
if key == headerPassthroughAllKey {
|
}
|
||||||
passAll = true
|
if key == headerPassthroughAllKey {
|
||||||
continue
|
passAll = true
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
lower := strings.ToLower(key)
|
lower := strings.ToLower(key)
|
||||||
var pattern string
|
var pattern string
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
||||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
||||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if pattern == "" {
|
if pattern == "" {
|
||||||
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
|
}
|
||||||
|
compiled, err := getHeaderPassthroughRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
|
}
|
||||||
|
passthroughRegex = append(passthroughRegex, compiled)
|
||||||
}
|
}
|
||||||
compiled, err := getHeaderPassthroughRegex(pattern)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
||||||
}
|
|
||||||
passthroughRegex = append(passthroughRegex, compiled)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if passAll || len(passthroughRegex) > 0 {
|
if passAll || len(passthroughRegex) > 0 {
|
||||||
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
}
|
}
|
||||||
|
if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
81
relay/channel/api_request_test.go
Normal file
81
relay/channel/api_request_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package channel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: true,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"*": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: true,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok := headers["X-Upstream-Trace"]
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: false,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user