fix: ignore header passthrough during channel tests
This commit is contained in:
@@ -171,6 +171,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
|
|
||||||
passAll := false
|
passAll := false
|
||||||
var passthroughRegex []*regexp.Regexp
|
var passthroughRegex []*regexp.Regexp
|
||||||
|
if !info.IsChannelTest {
|
||||||
for k := range info.HeadersOverride {
|
for k := range info.HeadersOverride {
|
||||||
key := strings.TrimSpace(k)
|
key := strings.TrimSpace(k)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
@@ -201,6 +202,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
}
|
}
|
||||||
passthroughRegex = append(passthroughRegex, compiled)
|
passthroughRegex = append(passthroughRegex, compiled)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if passAll || len(passthroughRegex) > 0 {
|
if passAll || len(passthroughRegex) > 0 {
|
||||||
if c == nil || c.Request == nil {
|
if c == nil || c.Request == nil {
|
||||||
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||||
}
|
}
|
||||||
|
if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
81
relay/channel/api_request_test.go
Normal file
81
relay/channel/api_request_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package channel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: true,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"*": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: true,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok := headers["X-Upstream-Trace"]
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
IsChannelTest: false,
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
HeadersOverride: map[string]any{
|
||||||
|
"X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := processHeaderOverride(info, ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user