diff --git a/controller/channel-test.go b/controller/channel-test.go index ab12132b..3947c8d5 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), } } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return testResult{ context: c, @@ -385,8 +385,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, //} if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return testResult{ + context: c, + localErr: fixedErr, + newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr), + } + } return testResult{ context: c, localErr: err, diff --git a/controller/relay.go b/controller/relay.go index edea1586..1788b25b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { ModelName: relayInfo.OriginModelName, Retry: common.GetPointer(0), } + relayInfo.RetryIndex = 0 + relayInfo.LastError = nil for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + relayInfo.RetryIndex = retryParam.GetRetry() channel, channelErr := getChannel(c, relayInfo, retryParam) if channelErr != nil { logger.LogError(c, channelErr.Error()) @@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } if newAPIError == nil { + relayInfo.LastError = nil return } newAPIError = service.NormalizeViolationFeeError(newAPIError) + relayInfo.LastError = newAPIError processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) diff --git a/middleware/distributor.go b/middleware/distributor.go index 9e66cb8f..db57998c 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -348,8 +348,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) - common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) - common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride()) + paramOverride := channel.GetParamOverride() + headerOverride := channel.GetHeaderOverride() + if mergedParam, applied := service.ApplyChannelAffinityOverrideTemplate(c, paramOverride); applied { + paramOverride = mergedParam + } + common.SetContextKey(c, constant.ContextKeyChannelParamOverride, paramOverride) + common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, headerOverride) if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 09ca855d..79eac3ad 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -169,12 +169,17 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str // Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win. func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) { headerOverride := make(map[string]string) + if info == nil { + return headerOverride, nil + } + + headerOverrideSource := common.GetEffectiveHeaderOverride(info) passAll := false var passthroughRegex []*regexp.Regexp if !info.IsChannelTest { - for k := range info.HeadersOverride { - key := strings.TrimSpace(k) + for k := range headerOverrideSource { + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } @@ -183,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s continue } - lower := strings.ToLower(key) var pattern string switch { - case strings.HasPrefix(lower, headerPassthroughRegexPrefix): + case strings.HasPrefix(key, headerPassthroughRegexPrefix): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):]) - case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2): + case strings.HasPrefix(key, headerPassthroughRegexPrefixV2): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):]) default: continue @@ -229,15 +233,15 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s if value == "" { continue } - headerOverride[name] = value + headerOverride[strings.ToLower(strings.TrimSpace(name))] = value } } - for k, v := range info.HeadersOverride { + for k, v := range headerOverrideSource { if isHeaderPassthroughRuleKey(k) { continue } - key := strings.TrimSpace(k) + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index 6c7834ef..f697f855 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - _, ok := headers["X-Upstream-Trace"] + _, ok := headers["x-upstream-trace"] require.False(t, ok) } @@ -77,7 +77,38 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) + require.Equal(t, "trace-123", headers["x-upstream-trace"]) +} + +func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]any{ + "x-static": "runtime-value", + "x-runtime": "runtime-only", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "X-Static": "legacy-value", + "X-Legacy": "legacy-only", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "runtime-value", headers["x-static"]) + require.Equal(t, "runtime-only", headers["x-runtime"]) + _, exists := headers["x-legacy"] + require.False(t, exists) } func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { @@ -101,8 +132,62 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Trace-Id"]) + require.Equal(t, "trace-123", headers["x-trace-id"]) - _, hasAcceptEncoding := headers["Accept-Encoding"] + _, hasAcceptEncoding := headers["accept-encoding"] require.False(t, hasAcceptEncoding) } + +func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + ctx.Request.Header.Set("Originator", "Codex CLI") + ctx.Request.Header.Set("Session_id", "sess-123") + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + "Session_id": "sess-123", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ParamOverride: map[string]any{ + "operations": []any{ + map[string]any{ + "mode": "pass_headers", + "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"}, + }, + }, + }, + HeadersOverride: map[string]any{ + "X-Static": "legacy-value", + }, + }, + } + + _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info) + require.NoError(t, err) + require.True(t, info.UseRuntimeHeadersOverride) + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] + require.False(t, exists) + require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"]) + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "Codex CLI", headers["originator"]) + require.Equal(t, "sess-123", headers["session_id"]) + _, exists = headers["x-codex-beta-features"] + require.False(t, exists) + + upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) + applyHeaderOverrideToRequest(upstreamReq, headers) + require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator")) + require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id")) + require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features")) +} diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 6412b7d2..8f69b937 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ } func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) { - overrideCtx := relaycommon.BuildParamOverrideContext(info) chatJSON, err := common.Marshal(request) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) @@ -82,9 +81,9 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad } if len(info.ParamOverride) > 0 { - chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx) + chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info) if err != nil { - return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return nil, newAPIErrorFromParamOverride(err) } } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 9b08781c..1722cd9b 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -153,9 +153,9 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/common/override.go b/relay/common/override.go index 1a0c2478..59e15176 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -1,18 +1,29 @@ package common import ( + "errors" "fmt" + "net/http" "regexp" "strconv" "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/types" + "github.com/samber/lo" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`) +const ( + paramOverrideContextRequestHeaders = "request_headers" + paramOverrideContextHeaderOverride = "header_override" +) + +var errSourceHeaderNotFound = errors.New("source header does not exist") + type ConditionOperation struct { Path string `json:"path"` // JSON路径 Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte @@ -23,7 +34,7 @@ type ConditionOperation struct { type ParamOperation struct { Path string `json:"path"` - Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace + Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects, set_header, delete_header, copy_header, move_header, pass_headers, sync_fields Value interface{} `json:"value"` KeepOrigin bool `json:"keep_origin"` From string `json:"from,omitempty"` @@ -32,6 +43,76 @@ type ParamOperation struct { Logic string `json:"logic,omitempty"` // AND, OR (默认OR) } +type ParamOverrideReturnError struct { + Message string + StatusCode int + Code string + Type string + SkipRetry bool +} + +func (e *ParamOverrideReturnError) Error() string { + if e == nil { + return "param override return error" + } + if e.Message == "" { + return "param override return error" + } + return e.Message +} + +func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) { + if err == nil { + return nil, false + } + var target *ParamOverrideReturnError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError { + if err == nil { + return types.NewError( + errors.New("param override return error is nil"), + types.ErrorCodeChannelParamOverrideInvalid, + types.ErrOptionWithSkipRetry(), + ) + } + + statusCode := err.StatusCode + if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired { + statusCode = http.StatusBadRequest + } + + errorCode := err.Code + if strings.TrimSpace(errorCode) == "" { + errorCode = string(types.ErrorCodeInvalidRequest) + } + + errorType := err.Type + if strings.TrimSpace(errorType) == "" { + errorType = "invalid_request_error" + } + + message := strings.TrimSpace(err.Message) + if message == "" { + message = "request blocked by param override" + } + + opts := make([]types.NewAPIErrorOptions, 0, 1) + if err.SkipRetry { + opts = append(opts, types.ErrOptionWithSkipRetry()) + } + + return types.WithOpenAIError(types.OpenAIError{ + Message: message, + Type: errorType, + Code: errorCode, + }, statusCode, opts...) +} + func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) { if len(paramOverride) == 0 { return jsonData, nil @@ -48,81 +129,147 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c return applyOperationsLegacy(jsonData, paramOverride) } -func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { - // 检查是否包含 "operations" 字段 - if opsValue, exists := paramOverride["operations"]; exists { - if opsSlice, ok := opsValue.([]interface{}); ok { - var operations []ParamOperation - for _, op := range opsSlice { - if opMap, ok := op.(map[string]interface{}); ok { - operation := ParamOperation{} - - // 断言必要字段 - if path, ok := opMap["path"].(string); ok { - operation.Path = path - } - if mode, ok := opMap["mode"].(string); ok { - operation.Mode = mode - } else { - return nil, false // mode 是必需的 - } - - // 可选字段 - if value, exists := opMap["value"]; exists { - operation.Value = value - } - if keepOrigin, ok := opMap["keep_origin"].(bool); ok { - operation.KeepOrigin = keepOrigin - } - if from, ok := opMap["from"].(string); ok { - operation.From = from - } - if to, ok := opMap["to"].(string); ok { - operation.To = to - } - if logic, ok := opMap["logic"].(string); ok { - operation.Logic = logic - } else { - operation.Logic = "OR" // 默认为OR - } - - // 解析条件 - if conditions, exists := opMap["conditions"]; exists { - if condSlice, ok := conditions.([]interface{}); ok { - for _, cond := range condSlice { - if condMap, ok := cond.(map[string]interface{}); ok { - condition := ConditionOperation{} - if path, ok := condMap["path"].(string); ok { - condition.Path = path - } - if mode, ok := condMap["mode"].(string); ok { - condition.Mode = mode - } - if value, ok := condMap["value"]; ok { - condition.Value = value - } - if invert, ok := condMap["invert"].(bool); ok { - condition.Invert = invert - } - if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok { - condition.PassMissingKey = passMissingKey - } - operation.Conditions = append(operation.Conditions, condition) - } - } - } - } - - operations = append(operations, operation) - } else { - return nil, false - } - } - return operations, true - } +func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) { + paramOverride := getParamOverrideMap(info) + if len(paramOverride) == 0 { + return jsonData, nil } - return nil, false + overrideCtx := BuildParamOverrideContext(info) + result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx) + if err != nil { + return nil, err + } + syncRuntimeHeaderOverrideFromContext(info, overrideCtx) + return result, nil +} + +func getParamOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return info.ChannelMeta.ParamOverride +} + +func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return info.ChannelMeta.HeadersOverride +} + +func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} { + if len(source) == 0 { + return map[string]interface{}{} + } + target := make(map[string]interface{}, len(source)) + for key, value := range source { + normalizedKey := normalizeHeaderContextKey(key) + if normalizedKey == "" { + continue + } + normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if normalizedValue == "" { + if isHeaderPassthroughRuleKeyForOverride(normalizedKey) { + target[normalizedKey] = "" + } + continue + } + target[normalizedKey] = normalizedValue + } + return target +} + +func isHeaderPassthroughRuleKeyForOverride(key string) bool { + key = strings.TrimSpace(strings.ToLower(key)) + if key == "" { + return false + } + if key == "*" { + return true + } + return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:") +} + +func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} { + if info == nil { + return map[string]interface{}{} + } + if info.UseRuntimeHeadersOverride { + return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride) + } + return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info)) +} + +func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { + // 检查是否包含 "operations" 字段 + opsValue, exists := paramOverride["operations"] + if !exists { + return nil, false + } + + var opMaps []map[string]interface{} + switch ops := opsValue.(type) { + case []interface{}: + opMaps = make([]map[string]interface{}, 0, len(ops)) + for _, op := range ops { + opMap, ok := op.(map[string]interface{}) + if !ok { + return nil, false + } + opMaps = append(opMaps, opMap) + } + case []map[string]interface{}: + opMaps = ops + default: + return nil, false + } + + operations := make([]ParamOperation, 0, len(opMaps)) + for _, opMap := range opMaps { + operation := ParamOperation{} + + // 断言必要字段 + if path, ok := opMap["path"].(string); ok { + operation.Path = path + } + if mode, ok := opMap["mode"].(string); ok { + operation.Mode = mode + } else { + return nil, false // mode 是必需的 + } + + // 可选字段 + if value, exists := opMap["value"]; exists { + operation.Value = value + } + if keepOrigin, ok := opMap["keep_origin"].(bool); ok { + operation.KeepOrigin = keepOrigin + } + if from, ok := opMap["from"].(string); ok { + operation.From = from + } + if to, ok := opMap["to"].(string); ok { + operation.To = to + } + if logic, ok := opMap["logic"].(string); ok { + operation.Logic = logic + } else { + operation.Logic = "OR" // 默认为OR + } + + // 解析条件 + if conditions, exists := opMap["conditions"]; exists { + parsedConditions, err := parseConditionOperations(conditions) + if err != nil { + return nil, false + } + operation.Conditions = append(operation.Conditions, parsedConditions...) + } + + operations = append(operations, operation) + } + return operations, true } func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { @@ -139,20 +286,9 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio } if strings.ToUpper(logic) == "AND" { - for _, result := range results { - if !result { - return false, nil - } - } - return true, nil - } else { - for _, result := range results { - if result { - return true, nil - } - } - return false, nil + return lo.EveryBy(results, func(item bool) bool { return item }), nil } + return lo.SomeBy(results, func(item bool) bool { return item }), nil } func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) { @@ -309,13 +445,10 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{} } func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { - var contextJSON string - if conditionContext != nil && len(conditionContext) > 0 { - ctxBytes, err := common.Marshal(conditionContext) - if err != nil { - return "", fmt.Errorf("failed to marshal condition context: %v", err) - } - contextJSON = string(ctxBytes) + context := ensureContextMap(conditionContext) + contextJSON, err := marshalContextJSON(context) + if err != nil { + return "", fmt.Errorf("failed to marshal condition context: %v", err) } result := jsonStr @@ -372,16 +505,631 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte result, err = replaceStringValue(result, opPath, op.From, op.To) case "regex_replace": result, err = regexReplaceStringValue(result, opPath, op.From, op.To) + case "return_error": + returnErr, parseErr := parseParamOverrideReturnError(op.Value) + if parseErr != nil { + return "", parseErr + } + return "", returnErr + case "prune_objects": + result, err = pruneObjects(result, opPath, contextJSON, op.Value) + case "set_header": + err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "delete_header": + err = deleteHeaderOverrideInContext(context, op.Path) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "copy_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "move_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "pass_headers": + headerNames, parseErr := parseHeaderPassThroughNames(op.Value) + if parseErr != nil { + return "", parseErr + } + for _, headerName := range headerNames { + if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil { + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + continue + } + break + } + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "sync_fields": + result, err = syncFieldsBetweenTargets(result, context, op.From, op.To) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } default: return "", fmt.Errorf("unknown operation: %s", op.Mode) } if err != nil { - return "", fmt.Errorf("operation %s failed: %v", op.Mode, err) + return "", fmt.Errorf("operation %s failed: %w", op.Mode, err) } } return result, nil } +func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) { + result := &ParamOverrideReturnError{ + StatusCode: http.StatusBadRequest, + Code: string(types.ErrorCodeInvalidRequest), + Type: "invalid_request_error", + SkipRetry: true, + } + + switch raw := value.(type) { + case nil: + return nil, fmt.Errorf("return_error value is required") + case string: + result.Message = strings.TrimSpace(raw) + case map[string]interface{}: + if message, ok := raw["message"].(string); ok { + result.Message = strings.TrimSpace(message) + } + if result.Message == "" { + if message, ok := raw["msg"].(string); ok { + result.Message = strings.TrimSpace(message) + } + } + + if code, exists := raw["code"]; exists { + codeStr := strings.TrimSpace(fmt.Sprintf("%v", code)) + if codeStr != "" { + result.Code = codeStr + } + } + if errType, ok := raw["type"].(string); ok { + errType = strings.TrimSpace(errType) + if errType != "" { + result.Type = errType + } + } + if skipRetry, ok := raw["skip_retry"].(bool); ok { + result.SkipRetry = skipRetry + } + + if statusCodeRaw, exists := raw["status_code"]; exists { + statusCode, ok := parseOverrideInt(statusCodeRaw) + if !ok { + return nil, fmt.Errorf("return_error status_code must be an integer") + } + result.StatusCode = statusCode + } else if statusRaw, exists := raw["status"]; exists { + statusCode, ok := parseOverrideInt(statusRaw) + if !ok { + return nil, fmt.Errorf("return_error status must be an integer") + } + result.StatusCode = statusCode + } + default: + return nil, fmt.Errorf("return_error value must be string or object") + } + + if result.Message == "" { + return nil, fmt.Errorf("return_error message is required") + } + if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired { + return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode) + } + + return result, nil +} + +func parseOverrideInt(v interface{}) (int, bool) { + switch value := v.(type) { + case int: + return value, true + case float64: + if value != float64(int(value)) { + return 0, false + } + return int(value), true + default: + return 0, false + } +} + +func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} { + if conditionContext != nil { + return conditionContext + } + return make(map[string]interface{}) +} + +func marshalContextJSON(context map[string]interface{}) (string, error) { + if context == nil || len(context) == 0 { + return "", nil + } + ctxBytes, err := common.Marshal(context) + if err != nil { + return "", err + } + return string(ctxBytes), nil +} + +func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + if keepOrigin { + if existing, ok := rawHeaders[headerName]; ok { + existingValue := strings.TrimSpace(fmt.Sprintf("%v", existing)) + if existingValue != "" { + return nil + } + } + } + + headerValue, hasValue, err := resolveHeaderOverrideValue(context, headerName, value) + if err != nil { + return err + } + if !hasValue { + delete(rawHeaders, headerName) + return nil + } + + rawHeaders[headerName] = headerValue + return nil +} + +func resolveHeaderOverrideValue(context map[string]interface{}, headerName string, value interface{}) (string, bool, error) { + if value == nil { + return "", false, fmt.Errorf("header value is required") + } + + if mapping, ok := value.(map[string]interface{}); ok { + return resolveHeaderOverrideValueByMapping(context, headerName, mapping) + } + if mapping, ok := value.(map[string]string); ok { + converted := make(map[string]interface{}, len(mapping)) + for key, item := range mapping { + converted[key] = item + } + return resolveHeaderOverrideValueByMapping(context, headerName, converted) + } + + headerValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if headerValue == "" { + return "", false, nil + } + return headerValue, true, nil +} + +func resolveHeaderOverrideValueByMapping(context map[string]interface{}, headerName string, mapping map[string]interface{}) (string, bool, error) { + if len(mapping) == 0 { + return "", false, fmt.Errorf("header value mapping cannot be empty") + } + + sourceValue, exists := getHeaderValueFromContext(context, headerName) + if !exists { + return "", false, nil + } + sourceTokens := splitHeaderListValue(sourceValue) + if len(sourceTokens) == 0 { + return "", false, nil + } + + wildcardValue, hasWildcard := mapping["*"] + resultTokens := make([]string, 0, len(sourceTokens)) + for _, token := range sourceTokens { + replacementRaw, hasReplacement := mapping[token] + if !hasReplacement && hasWildcard { + replacementRaw = wildcardValue + hasReplacement = true + } + if !hasReplacement { + resultTokens = append(resultTokens, token) + continue + } + replacementTokens, err := parseHeaderReplacementTokens(replacementRaw) + if err != nil { + return "", false, err + } + resultTokens = append(resultTokens, replacementTokens...) + } + + resultTokens = lo.Uniq(resultTokens) + if len(resultTokens) == 0 { + return "", false, nil + } + return strings.Join(resultTokens, ","), true, nil +} + +func parseHeaderReplacementTokens(value interface{}) ([]string, error) { + switch raw := value.(type) { + case nil: + return nil, nil + case string: + return splitHeaderListValue(raw), nil + case []string: + tokens := make([]string, 0, len(raw)) + for _, item := range raw { + tokens = append(tokens, splitHeaderListValue(item)...) + } + return lo.Uniq(tokens), nil + case []interface{}: + tokens := make([]string, 0, len(raw)) + for _, item := range raw { + itemTokens, err := parseHeaderReplacementTokens(item) + if err != nil { + return nil, err + } + tokens = append(tokens, itemTokens...) + } + return lo.Uniq(tokens), nil + case map[string]interface{}, map[string]string: + return nil, fmt.Errorf("header replacement value must be string, array or null") + default: + token := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if token == "" { + return nil, nil + } + return []string{token}, nil + } +} + +func splitHeaderListValue(raw string) []string { + items := strings.Split(raw, ",") + return lo.FilterMap(items, func(item string, _ int) (string, bool) { + token := strings.TrimSpace(item) + if token == "" { + return "", false + } + return token, true + }) +} + +func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("copy_header from/to is required") + } + value, exists := getHeaderValueFromContext(context, fromHeader) + if !exists { + return fmt.Errorf("%w: %s", errSourceHeaderNotFound, fromHeader) + } + return setHeaderOverrideInContext(context, toHeader, value, keepOrigin) +} + +func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("move_header from/to is required") + } + if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil { + return err + } + if strings.EqualFold(fromHeader, toHeader) { + return nil + } + return deleteHeaderOverrideInContext(context, fromHeader) +} + +func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + delete(rawHeaders, headerName) + return nil +} + +func parseHeaderPassThroughNames(value interface{}) ([]string, error) { + normalizeNames := func(values []string) []string { + names := lo.FilterMap(values, func(item string, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(item) + if headerName == "" { + return "", false + } + return headerName, true + }) + return lo.Uniq(names) + } + + switch raw := value.(type) { + case nil: + return nil, fmt.Errorf("pass_headers value is required") + case string: + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil, fmt.Errorf("pass_headers value is required") + } + if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") { + var parsed interface{} + if err := common.UnmarshalJsonStr(trimmed, &parsed); err == nil { + return parseHeaderPassThroughNames(parsed) + } + } + names := normalizeNames(strings.Split(trimmed, ",")) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case []interface{}: + names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(fmt.Sprintf("%v", item)) + if headerName == "" { + return "", false + } + return headerName, true + }) + names = lo.Uniq(names) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case []string: + names := lo.FilterMap(raw, func(item string, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(item) + if headerName == "" { + return "", false + } + return headerName, true + }) + names = lo.Uniq(names) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case map[string]interface{}: + candidates := make([]string, 0, 8) + if headersRaw, ok := raw["headers"]; ok { + names, err := parseHeaderPassThroughNames(headersRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + if namesRaw, ok := raw["names"]; ok { + names, err := parseHeaderPassThroughNames(namesRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + if headerRaw, ok := raw["header"]; ok { + names, err := parseHeaderPassThroughNames(headerRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + names := normalizeNames(candidates) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + default: + return nil, fmt.Errorf("pass_headers value must be string, array or object") + } +} + +type syncTarget struct { + kind string + key string +} + +func parseSyncTarget(spec string) (syncTarget, error) { + raw := strings.TrimSpace(spec) + if raw == "" { + return syncTarget{}, fmt.Errorf("sync_fields target is required") + } + + idx := strings.Index(raw, ":") + if idx < 0 { + // Backward compatibility: treat bare value as JSON path. + return syncTarget{ + kind: "json", + key: raw, + }, nil + } + + kind := strings.ToLower(strings.TrimSpace(raw[:idx])) + key := strings.TrimSpace(raw[idx+1:]) + if key == "" { + return syncTarget{}, fmt.Errorf("sync_fields target key is required: %s", raw) + } + + switch kind { + case "json", "body": + return syncTarget{ + kind: "json", + key: key, + }, nil + case "header": + return syncTarget{ + kind: "header", + key: key, + }, nil + default: + return syncTarget{}, fmt.Errorf("sync_fields target prefix is invalid: %s", raw) + } +} + +func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) { + switch target.kind { + case "json": + path := processNegativeIndex(jsonStr, target.key) + value := gjson.Get(jsonStr, path) + if !value.Exists() || value.Type == gjson.Null { + return nil, false, nil + } + if value.Type == gjson.String && strings.TrimSpace(value.String()) == "" { + return nil, false, nil + } + return value.Value(), true, nil + case "header": + value, ok := getHeaderValueFromContext(context, target.key) + if !ok || strings.TrimSpace(value) == "" { + return nil, false, nil + } + return value, true, nil + default: + return nil, false, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) + } +} + +func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) { + switch target.kind { + case "json": + path := processNegativeIndex(jsonStr, target.key) + nextJSON, err := sjson.Set(jsonStr, path, value) + if err != nil { + return "", err + } + return nextJSON, nil + case "header": + if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil { + return "", err + } + return jsonStr, nil + default: + return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) + } +} + +func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) { + fromTarget, err := parseSyncTarget(fromSpec) + if err != nil { + return "", err + } + toTarget, err := parseSyncTarget(toSpec) + if err != nil { + return "", err + } + + fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget) + if err != nil { + return "", err + } + toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget) + if err != nil { + return "", err + } + + // If one side exists and the other side is missing, sync the missing side. + if fromExists && !toExists { + return writeSyncTargetValue(jsonStr, context, toTarget, fromValue) + } + if toExists && !fromExists { + return writeSyncTargetValue(jsonStr, context, fromTarget, toValue) + } + return jsonStr, nil +} + +func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} { + if context == nil { + return map[string]interface{}{} + } + if existing, ok := context[key]; ok { + if mapVal, ok := existing.(map[string]interface{}); ok { + return mapVal + } + } + result := make(map[string]interface{}) + context[key] = result + return result +} + +func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return "", false + } + for _, key := range []string{paramOverrideContextHeaderOverride, paramOverrideContextRequestHeaders} { + source := ensureMapKeyInContext(context, key) + raw, ok := source[headerName] + if !ok { + continue + } + value := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if value != "" { + return value, true + } + } + return "", false +} + +func normalizeHeaderContextKey(key string) string { + return strings.TrimSpace(strings.ToLower(key)) +} + +func buildRequestHeadersContext(headers map[string]string) map[string]interface{} { + if len(headers) == 0 { + return map[string]interface{}{} + } + entries := lo.Entries(headers) + normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) { + normalized := normalizeHeaderContextKey(item.Key) + value := strings.TrimSpace(item.Value) + if normalized == "" || value == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: normalized, Value: value}, true + }) + return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) +} + +func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) { + if info == nil || context == nil { + return + } + raw, exists := context[paramOverrideContextHeaderOverride] + if !exists { + return + } + rawMap, ok := raw.(map[string]interface{}) + if !ok { + return + } + info.RuntimeHeadersOverride = sanitizeHeaderOverrideMap(rawMap) + info.UseRuntimeHeadersOverride = true +} + func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -537,6 +1285,235 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement)) } +type pruneObjectsOptions struct { + conditions []ConditionOperation + logic string + recursive bool +} + +func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) { + options, err := parsePruneObjectsOptions(value) + if err != nil { + return "", err + } + + if path == "" { + var root interface{} + if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { + return "", err + } + cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return string(cleanedBytes), nil + } + + target := gjson.Get(jsonStr, path) + if !target.Exists() { + return jsonStr, nil + } + + var targetNode interface{} + if target.Type == gjson.JSON { + if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil { + return "", err + } + } else { + targetNode = target.Value() + } + + cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return sjson.SetRaw(jsonStr, path, string(cleanedBytes)) +} + +func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { + opts := pruneObjectsOptions{ + logic: "AND", + recursive: true, + } + + switch raw := value.(type) { + case nil: + return opts, fmt.Errorf("prune_objects value is required") + case string: + v := strings.TrimSpace(raw) + if v == "" { + return opts, fmt.Errorf("prune_objects value is required") + } + opts.conditions = []ConditionOperation{ + { + Path: "type", + Mode: "full", + Value: v, + }, + } + case map[string]interface{}: + if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" { + opts.logic = logic + } + if recursive, ok := raw["recursive"].(bool); ok { + opts.recursive = recursive + } + + if condRaw, exists := raw["conditions"]; exists { + conditions, err := parseConditionOperations(condRaw) + if err != nil { + return opts, err + } + opts.conditions = append(opts.conditions, conditions...) + } + + if whereRaw, exists := raw["where"]; exists { + whereMap, ok := whereRaw.(map[string]interface{}) + if !ok { + return opts, fmt.Errorf("prune_objects where must be object") + } + for key, val := range whereMap { + key = strings.TrimSpace(key) + if key == "" { + continue + } + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: key, + Mode: "full", + Value: val, + }) + } + } + + if matchType, exists := raw["type"]; exists { + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: "type", + Mode: "full", + Value: matchType, + }) + } + default: + return opts, fmt.Errorf("prune_objects value must be string or object") + } + + if len(opts.conditions) == 0 { + return opts, fmt.Errorf("prune_objects conditions are required") + } + return opts, nil +} + +func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) { + switch typed := raw.(type) { + case map[string]interface{}: + entries := lo.Entries(typed) + conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) { + path := strings.TrimSpace(item.Key) + if path == "" { + return ConditionOperation{}, false + } + return ConditionOperation{ + Path: path, + Mode: "full", + Value: item.Value, + }, true + }) + if len(conditions) == 0 { + return nil, fmt.Errorf("conditions object must contain at least one key") + } + return conditions, nil + case []interface{}: + items := typed + result := make([]ConditionOperation, 0, len(items)) + for _, item := range items { + itemMap, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("condition must be object") + } + path, _ := itemMap["path"].(string) + mode, _ := itemMap["mode"].(string) + if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" { + return nil, fmt.Errorf("condition path/mode is required") + } + condition := ConditionOperation{ + Path: path, + Mode: mode, + } + if value, exists := itemMap["value"]; exists { + condition.Value = value + } + if invert, ok := itemMap["invert"].(bool); ok { + condition.Invert = invert + } + if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok { + condition.PassMissingKey = passMissingKey + } + result = append(result, condition) + } + return result, nil + default: + return nil, fmt.Errorf("conditions must be an array or object") + } +} + +func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) { + switch value := node.(type) { + case []interface{}: + result := make([]interface{}, 0, len(value)) + for _, item := range value { + next, drop, err := pruneObjectsNode(item, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + continue + } + result = append(result, next) + } + return result, false, nil + case map[string]interface{}: + shouldDrop, err := shouldPruneObject(value, options, contextJSON) + if err != nil { + return nil, false, err + } + if shouldDrop && !isRoot { + return nil, true, nil + } + if !options.recursive { + return value, false, nil + } + for key, child := range value { + next, drop, err := pruneObjectsNode(child, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + delete(value, key) + continue + } + value[key] = next + } + return value, false, nil + default: + return node, false, nil + } +} + +func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) { + nodeBytes, err := common.Marshal(node) + if err != nil { + return false, err + } + return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic) +} + func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { current := gjson.Get(jsonStr, path) var currentMap, newMap map[string]interface{} @@ -598,6 +1575,37 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { } } + ctx[paramOverrideContextRequestHeaders] = buildRequestHeadersContext(info.RequestHeaders) + + headerOverrideSource := GetEffectiveHeaderOverride(info) + ctx[paramOverrideContextHeaderOverride] = sanitizeHeaderOverrideMap(headerOverrideSource) + + ctx["retry_index"] = info.RetryIndex + ctx["is_retry"] = info.RetryIndex > 0 + ctx["retry"] = map[string]interface{}{ + "index": info.RetryIndex, + "is_retry": info.RetryIndex > 0, + } + + if info.LastError != nil { + code := string(info.LastError.GetErrorCode()) + errorType := string(info.LastError.GetErrorType()) + lastError := map[string]interface{}{ + "status_code": info.LastError.StatusCode, + "message": info.LastError.Error(), + "code": code, + "error_code": code, + "type": errorType, + "error_type": errorType, + "skip_retry": types.IsSkipRetryError(info.LastError), + } + ctx["last_error"] = lastError + ctx["last_error_status_code"] = info.LastError.StatusCode + ctx["last_error_message"] = info.LastError.Error() + ctx["last_error_code"] = code + ctx["last_error_type"] = errorType + } + ctx["is_channel_test"] = info.IsChannelTest return ctx } diff --git a/relay/common/override_test.go b/relay/common/override_test.go index c83cddff..5f49d95a 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -5,6 +5,8 @@ import ( "reflect" "testing" + "github.com/QuantumNous/new-api/types" + "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/setting/model_setting" ) @@ -775,6 +777,754 @@ func TestApplyParamOverrideToUpper(t *testing.T) { assertJSONEqual(t, `{"model":"GPT-4"}`, string(out)) } +func TestApplyParamOverrideReturnError(t *testing.T) { + input := []byte(`{"model":"gemini-2.5-pro"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "return_error", + "value": map[string]interface{}{ + "message": "forced bad request by param override", + "status_code": 422, + "code": "forced_bad_request", + "type": "invalid_request_error", + "skip_retry": true, + }, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "retry.is_retry", + "mode": "full", + "value": true, + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "retry": map[string]interface{}{ + "index": 1, + "is_retry": true, + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err == nil { + t.Fatalf("expected error, got nil") + } + returnErr, ok := AsParamOverrideReturnError(err) + if !ok { + t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err) + } + if returnErr.StatusCode != 422 { + t.Fatalf("expected status 422, got %d", returnErr.StatusCode) + } + if returnErr.Code != "forced_bad_request" { + t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code) + } + if !returnErr.SkipRetry { + t.Fatalf("expected skip_retry true") + } +} + +func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) { + input := []byte(`{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"redacted_thinking","text":"secret"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"redacted_thinking","text":"secret2"}, + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "prune_objects", + "value": "redacted_thinking", + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`, string(out)) +} + +func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) { + input := []byte(`{ + "a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "a", + "mode": "prune_objects", + "value": map[string]interface{}{ + "where": map[string]interface{}{ + "type": "redacted_thinking", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "a":{"items":[{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`, string(out)) +} + +func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) { + input := []byte(`{"items":[{"type":"redacted_thinking"}]}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "normalize_thinking_signature", + }, + }, + } + + _, err := ApplyParamOverride(input, override, nil) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) { + info := &RelayInfo{ + RetryIndex: 1, + LastError: types.WithOpenAIError(types.OpenAIError{ + Message: "invalid thinking signature", + Type: "invalid_request_error", + Code: "bad_thought_signature", + }, 400), + } + ctx := BuildParamOverrideContext(info) + + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": []interface{}{ + map[string]interface{}{ + "path": "is_retry", + "mode": "full", + "value": true, + }, + map[string]interface{}{ + "path": "last_error.code", + "mode": "contains", + "value": "thought_signature", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "request_headers.authorization", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Debug-Mode", + "value": "enabled", + }, + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override.x-debug-mode", + "mode": "full", + "value": "enabled", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "Authorization", + "to": "X-Upstream-Auth", + }, + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override.x-upstream-auth", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "pass_headers", + "value": []interface{}{"X-Codex-Beta-Features", "Session_id"}, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "sess-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["session_id"] != "sess-123" { + t.Fatalf("expected session_id to be passed, got: %v", headers["session_id"]) + } + if _, exists := headers["x-codex-beta-features"]; exists { + t.Fatalf("expected missing header to be skipped") + } +} + +func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["x-upstream-auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + +func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["x-upstream-auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + +func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) { + input := []byte(`{"model":"gpt-4"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "sess-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"sess-123"}`, string(out)) +} + +func TestApplyParamOverrideSyncFieldsJSONToHeader(t *testing.T) { + input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-abc"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{} + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-abc"}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["session_id"] != "cache-abc" { + t.Fatalf("expected session_id to be synced from prompt_cache_key, got: %v", headers["session_id"]) + } +} + +func TestApplyParamOverrideSyncFieldsNoChangeWhenBothExist(t *testing.T) { + input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-body"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "cache-header", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-body"}`, string(out)) + + headers, _ := ctx["header_override"].(map[string]interface{}) + if headers != nil { + if _, exists := headers["session_id"]; exists { + t.Fatalf("expected no override when both sides already have value") + } + } +} + +func TestApplyParamOverrideSyncFieldsInvalidTarget(t *testing.T) { + input := []byte(`{"model":"gpt-4"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "foo:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + + _, err := ApplyParamOverride(input, override, nil) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Feature-Flag", + "value": "new-value", + "keep_origin": true, + }, + }, + } + ctx := map[string]interface{}{ + "header_override": map[string]interface{}{ + "x-feature-flag": "legacy-value", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["x-feature-flag"] != "legacy-value" { + t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["x-feature-flag"]) + } +} + +func TestApplyParamOverrideSetHeaderMapRewritesCommaSeparatedHeader(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": "computer-use-2025-01-24", + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["anthropic-beta"] != "computer-use-2025-01-24" { + t.Fatalf("expected anthropic-beta to keep only mapped value, got: %v", headers["anthropic-beta"]) + } +} + +func TestApplyParamOverrideSetHeaderMapDeleteWholeHeaderWhenAllTokensCleared(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": nil, + }, + }, + }, + } + ctx := map[string]interface{}{ + "header_override": map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20,computer-use-2025-01-24", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if _, exists := headers["anthropic-beta"]; exists { + t.Fatalf("expected anthropic-beta to be deleted when all mapped values are null") + } +} + +func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": map[string]interface{}{ + "is_retry": true, + "last_error.status_code": 400.0, + }, + }, + }, + } + ctx := map[string]interface{}{ + "is_retry": true, + "last_error": map[string]interface{}{ + "status_code": 400.0, + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Injected-By-Param-Override", + "value": "enabled", + }, + map[string]interface{}{ + "mode": "delete_header", + "path": "X-Delete-Me", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Delete-Me": "legacy", + "X-Keep-Me": "keep", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + out, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["x-keep-me"] != "keep" { + t.Fatalf("expected x-keep-me header to be preserved, got: %v", info.RuntimeHeadersOverride["x-keep-me"]) + } + if info.RuntimeHeadersOverride["x-injected-by-param-override"] != "enabled" { + t.Fatalf("expected x-injected-by-param-override header to be set, got: %v", info.RuntimeHeadersOverride["x-injected-by-param-override"]) + } + if _, exists := info.RuntimeHeadersOverride["x-delete-me"]; exists { + t.Fatalf("expected x-delete-me header to be deleted") + } +} + +func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Legacy-Trace", + "to": "X-Trace", + }, + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Trace", + "to": "X-Trace-Backup", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Legacy-Trace": "trace-123", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + _, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + if _, exists := info.RuntimeHeadersOverride["x-legacy-trace"]; exists { + t.Fatalf("expected source header to be removed after move") + } + if info.RuntimeHeadersOverride["x-trace"] != "trace-123" { + t.Fatalf("expected x-trace to be set, got: %v", info.RuntimeHeadersOverride["x-trace"]) + } + if info.RuntimeHeadersOverride["x-trace-backup"] != "trace-123" { + t.Fatalf("expected x-trace-backup to be copied, got: %v", info.RuntimeHeadersOverride["x-trace-backup"]) + } +} + +func TestApplyParamOverrideWithRelayInfoSetHeaderMapRewritesAnthropicBeta(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": "computer-use-2025-01-24", + }, + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24", + }, + }, + } + + _, err := ApplyParamOverrideWithRelayInfo([]byte(`{"temperature":0.7}`), info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["anthropic-beta"] != "computer-use-2025-01-24" { + t.Fatalf("expected anthropic-beta to be rewritten, got: %v", info.RuntimeHeadersOverride["anthropic-beta"]) + } +} + +func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) { + info := &RelayInfo{ + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]interface{}{ + "x-runtime": "runtime-only", + }, + ChannelMeta: &ChannelMeta{ + HeadersOverride: map[string]interface{}{ + "X-Static": "static-value", + "X-Deleted": "should-not-exist", + }, + }, + } + + effective := GetEffectiveHeaderOverride(info) + if effective["x-runtime"] != "runtime-only" { + t.Fatalf("expected x-runtime from runtime override, got: %v", effective["x-runtime"]) + } + if _, exists := effective["x-static"]; exists { + t.Fatalf("expected runtime override to be final and not merge channel headers") + } +} + func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) { input := `{ "service_tier":"flex", diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 6d286d61..8b0789c0 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -101,6 +101,7 @@ type RelayInfo struct { RelayMode int OriginModelName string RequestURLPath string + RequestHeaders map[string]string ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -144,6 +145,10 @@ type RelayInfo struct { SubscriptionAmountUsedAfterPreConsume int64 IsClaudeBetaQuery bool // /v1/messages?beta=true IsChannelTest bool // channel test request + RetryIndex int + LastError *types.NewAPIError + RuntimeHeadersOverride map[string]interface{} + UseRuntimeHeadersOverride bool PriceData types.PriceData @@ -461,6 +466,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RequestURLPath: c.Request.URL.String(), + RequestHeaders: cloneRequestHeaders(c), IsStream: isStream, StartTime: startTime, @@ -493,6 +499,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { return info } +func cloneRequestHeaders(c *gin.Context) map[string]string { + if c == nil || c.Request == nil { + return nil + } + if len(c.Request.Header) == 0 { + return nil + } + headers := make(map[string]string, len(c.Request.Header)) + for key := range c.Request.Header { + value := strings.TrimSpace(c.Request.Header.Get(key)) + if value == "" { + continue + } + headers[key] = value + } + if len(headers) == 0 { + return nil + } + return headers +} + func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { var info *RelayInfo var err error diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index a133bab8..9a25237c 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -172,9 +172,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 1a41756b..d8ca4223 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "fmt" "net/http" @@ -46,15 +45,15 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } relaycommon.AppendRequestConversionFromRequest(info, convertedRequest) - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index a1b8e592..39bd44e6 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -157,9 +157,9 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } @@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI // apply param override if len(info.ParamOverride) > 0 { - reqMap := make(map[string]interface{}) - _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range info.ParamOverride { - reqMap[key] = value - } - jsonData, err = common.Marshal(reqMap) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData)) diff --git a/relay/image_handler.go b/relay/image_handler.go index e8329426..fc8ef500 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -70,9 +70,9 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/param_override_error.go b/relay/param_override_error.go new file mode 100644 index 00000000..c2338298 --- /dev/null +++ b/relay/param_override_error.go @@ -0,0 +1,13 @@ +package relay + +import ( + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" +) + +func newAPIErrorFromParamOverride(err error) *types.NewAPIError { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return relaycommon.NewAPIErrorFromParamOverride(fixedErr) + } + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) +} diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 8fe2930e..40d686f7 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -61,9 +61,9 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index b3169e72..18f1b711 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -96,9 +96,9 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/service/channel_affinity.go b/service/channel_affinity.go index 524c6574..3e90b9c2 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -45,6 +45,7 @@ type channelAffinityMeta struct { TTLSeconds int RuleName string SkipRetry bool + ParamTemplate map[string]interface{} KeySourceType string KeySourceKey string KeySourcePath string @@ -415,6 +416,84 @@ func buildChannelAffinityKeyHint(s string) string { return s[:4] + "..." + s[len(s)-4:] } +func cloneStringAnyMap(src map[string]interface{}) map[string]interface{} { + if len(src) == 0 { + return map[string]interface{}{} + } + dst := make(map[string]interface{}, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func mergeChannelOverride(base map[string]interface{}, tpl map[string]interface{}) map[string]interface{} { + if len(base) == 0 && len(tpl) == 0 { + return map[string]interface{}{} + } + if len(tpl) == 0 { + return base + } + out := cloneStringAnyMap(base) + for k, v := range tpl { + out[k] = v + } + return out +} + +func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinityMeta) { + if c == nil { + return + } + if len(meta.ParamTemplate) == 0 { + return + } + + templateInfo := map[string]interface{}{ + "applied": true, + "rule_name": meta.RuleName, + "param_override_keys": len(meta.ParamTemplate), + } + if anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo); ok { + if info, ok := anyInfo.(map[string]interface{}); ok { + info["override_template"] = templateInfo + c.Set(ginKeyChannelAffinityLogInfo, info) + return + } + } + c.Set(ginKeyChannelAffinityLogInfo, map[string]interface{}{ + "reason": meta.RuleName, + "rule_name": meta.RuleName, + "using_group": meta.UsingGroup, + "model": meta.ModelName, + "request_path": meta.RequestPath, + "key_source": meta.KeySourceType, + "key_key": meta.KeySourceKey, + "key_path": meta.KeySourcePath, + "key_hint": meta.KeyHint, + "key_fp": meta.KeyFingerprint, + "override_template": templateInfo, + }) +} + +// ApplyChannelAffinityOverrideTemplate merges per-rule channel override templates onto the selected channel override config. +func ApplyChannelAffinityOverrideTemplate(c *gin.Context, paramOverride map[string]interface{}) (map[string]interface{}, bool) { + if c == nil { + return paramOverride, false + } + meta, ok := getChannelAffinityMeta(c) + if !ok { + return paramOverride, false + } + if len(meta.ParamTemplate) == 0 { + return paramOverride, false + } + + mergedParam := mergeChannelOverride(paramOverride, meta.ParamTemplate) + appendChannelAffinityTemplateAdminInfo(c, meta) + return mergedParam, true +} + func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) { setting := operation_setting.GetChannelAffinitySetting() if setting == nil || !setting.Enabled { @@ -466,6 +545,7 @@ func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup TTLSeconds: ttlSeconds, RuleName: rule.Name, SkipRetry: rule.SkipRetryOnFailure, + ParamTemplate: cloneStringAnyMap(rule.ParamOverrideTemplate), KeySourceType: strings.TrimSpace(usedSource.Type), KeySourceKey: strings.TrimSpace(usedSource.Key), KeySourcePath: strings.TrimSpace(usedSource.Path), diff --git a/service/channel_affinity_template_test.go b/service/channel_affinity_template_test.go new file mode 100644 index 00000000..acf30154 --- /dev/null +++ b/service/channel_affinity_template_test.go @@ -0,0 +1,145 @@ +package service + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func buildChannelAffinityTemplateContextForTest(meta channelAffinityMeta) *gin.Context { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + setChannelAffinityContext(ctx, meta) + return ctx +} + +func TestApplyChannelAffinityOverrideTemplate_NoTemplate(t *testing.T) { + ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + RuleName: "rule-no-template", + }) + base := map[string]interface{}{ + "temperature": 0.7, + } + + merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) + require.False(t, applied) + require.Equal(t, base, merged) +} + +func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) { + ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + RuleName: "rule-with-template", + ParamTemplate: map[string]interface{}{ + "temperature": 0.2, + "top_p": 0.95, + }, + UsingGroup: "default", + ModelName: "gpt-4.1", + RequestPath: "/v1/responses", + KeySourceType: "gjson", + KeySourcePath: "prompt_cache_key", + KeyHint: "abcd...wxyz", + KeyFingerprint: "abcd1234", + }) + base := map[string]interface{}{ + "temperature": 0.7, + "max_tokens": 2000, + } + + merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) + require.True(t, applied) + require.Equal(t, 0.2, merged["temperature"]) + require.Equal(t, 0.95, merged["top_p"]) + require.Equal(t, 2000, merged["max_tokens"]) + require.Equal(t, 0.7, base["temperature"]) + + anyInfo, ok := ctx.Get(ginKeyChannelAffinityLogInfo) + require.True(t, ok) + info, ok := anyInfo.(map[string]interface{}) + require.True(t, ok) + overrideInfoAny, ok := info["override_template"] + require.True(t, ok) + overrideInfo, ok := overrideInfoAny.(map[string]interface{}) + require.True(t, ok) + require.Equal(t, true, overrideInfo["applied"]) + require.Equal(t, "rule-with-template", overrideInfo["rule_name"]) + require.EqualValues(t, 2, overrideInfo["param_override_keys"]) +} + +func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) { + gin.SetMode(gin.TestMode) + + setting := operation_setting.GetChannelAffinitySetting() + require.NotNil(t, setting) + + var codexRule *operation_setting.ChannelAffinityRule + for i := range setting.Rules { + rule := &setting.Rules[i] + if strings.EqualFold(strings.TrimSpace(rule.Name), "codex cli trace") { + codexRule = rule + break + } + } + require.NotNil(t, codexRule) + + affinityValue := fmt.Sprintf("pc-hit-%d", time.Now().UnixNano()) + cacheKeySuffix := buildChannelAffinityCacheKeySuffix(*codexRule, "default", affinityValue) + + cache := getChannelAffinityCache() + require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute)) + t.Cleanup(func() { + _, _ = cache.DeleteMany([]string{cacheKeySuffix}) + }) + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(fmt.Sprintf(`{"prompt_cache_key":"%s"}`, affinityValue))) + ctx.Request.Header.Set("Content-Type", "application/json") + + channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default") + require.True(t, found) + require.Equal(t, 9527, channelID) + + baseOverride := map[string]interface{}{ + "temperature": 0.2, + } + mergedOverride, applied := ApplyChannelAffinityOverrideTemplate(ctx, baseOverride) + require.True(t, applied) + require.Equal(t, 0.2, mergedOverride["temperature"]) + + info := &relaycommon.RelayInfo{ + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + "Session_id": "sess-123", + "User-Agent": "codex-cli-test", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ParamOverride: mergedOverride, + HeadersOverride: map[string]interface{}{ + "X-Static": "legacy-static", + }, + }, + } + + _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5"}`), info) + require.NoError(t, err) + require.True(t, info.UseRuntimeHeadersOverride) + + require.Equal(t, "legacy-static", info.RuntimeHeadersOverride["x-static"]) + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + require.Equal(t, "codex-cli-test", info.RuntimeHeadersOverride["user-agent"]) + + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] + require.False(t, exists) + _, exists = info.RuntimeHeadersOverride["x-codex-turn-metadata"] + require.False(t, exists) +} diff --git a/setting/operation_setting/channel_affinity_setting.go b/setting/operation_setting/channel_affinity_setting.go index 22f19824..7727315a 100644 --- a/setting/operation_setting/channel_affinity_setting.go +++ b/setting/operation_setting/channel_affinity_setting.go @@ -18,6 +18,8 @@ type ChannelAffinityRule struct { ValueRegex string `json:"value_regex"` TTLSeconds int `json:"ttl_seconds"` + ParamOverrideTemplate map[string]interface{} `json:"param_override_template,omitempty"` + SkipRetryOnFailure bool `json:"skip_retry_on_failure,omitempty"` IncludeUsingGroup bool `json:"include_using_group"` @@ -32,6 +34,44 @@ type ChannelAffinitySetting struct { Rules []ChannelAffinityRule `json:"rules"` } +var codexCliPassThroughHeaders = []string{ + "Originator", + "Session_id", + "User-Agent", + "X-Codex-Beta-Features", + "X-Codex-Turn-Metadata", +} + +var claudeCliPassThroughHeaders = []string{ + "X-Stainless-Arch", + "X-Stainless-Lang", + "X-Stainless-Os", + "X-Stainless-Package-Version", + "X-Stainless-Retry-Count", + "X-Stainless-Runtime", + "X-Stainless-Runtime-Version", + "X-Stainless-Timeout", + "User-Agent", + "X-App", + "Anthropic-Beta", + "Anthropic-Dangerous-Direct-Browser-Access", + "Anthropic-Version", +} + +func buildPassHeaderTemplate(headers []string) map[string]interface{} { + clonedHeaders := make([]string, 0, len(headers)) + clonedHeaders = append(clonedHeaders, headers...) + return map[string]interface{}{ + "operations": []map[string]interface{}{ + { + "mode": "pass_headers", + "value": clonedHeaders, + "keep_origin": true, + }, + }, + } +} + var channelAffinitySetting = ChannelAffinitySetting{ Enabled: true, SwitchOnSuccess: true, @@ -39,32 +79,34 @@ var channelAffinitySetting = ChannelAffinitySetting{ DefaultTTLSeconds: 3600, Rules: []ChannelAffinityRule{ { - Name: "codex trace", + Name: "codex cli trace", ModelRegex: []string{"^gpt-.*$"}, PathRegex: []string{"/v1/responses"}, KeySources: []ChannelAffinityKeySource{ {Type: "gjson", Path: "prompt_cache_key"}, }, - ValueRegex: "", - TTLSeconds: 0, - SkipRetryOnFailure: false, - IncludeUsingGroup: true, - IncludeRuleName: true, - UserAgentInclude: nil, + ValueRegex: "", + TTLSeconds: 0, + ParamOverrideTemplate: buildPassHeaderTemplate(codexCliPassThroughHeaders), + SkipRetryOnFailure: false, + IncludeUsingGroup: true, + IncludeRuleName: true, + UserAgentInclude: nil, }, { - Name: "claude code trace", + Name: "claude cli trace", ModelRegex: []string{"^claude-.*$"}, PathRegex: []string{"/v1/messages"}, KeySources: []ChannelAffinityKeySource{ {Type: "gjson", Path: "metadata.user_id"}, }, - ValueRegex: "", - TTLSeconds: 0, - SkipRetryOnFailure: false, - IncludeUsingGroup: true, - IncludeRuleName: true, - UserAgentInclude: nil, + ValueRegex: "", + TTLSeconds: 0, + ParamOverrideTemplate: buildPassHeaderTemplate(claudeCliPassThroughHeaders), + SkipRetryOnFailure: false, + IncludeUsingGroup: true, + IncludeRuleName: true, + UserAgentInclude: nil, }, }, } diff --git a/web/package.json b/web/package.json index 97c7c821..4d8c7e7f 100644 --- a/web/package.json +++ b/web/package.json @@ -10,7 +10,7 @@ "@visactor/react-vchart": "~1.8.8", "@visactor/vchart": "~1.8.8", "@visactor/vchart-semi-theme": "~1.8.8", - "axios": "1.13.5", + "axios": "1.12.0", "clsx": "^2.1.1", "dayjs": "^1.11.11", "history": "^5.3.0", diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 3d3afcc3..3a91207d 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -59,6 +59,7 @@ import ModelSelectModal from './ModelSelectModal'; import SingleModelSelectModal from './SingleModelSelectModal'; import OllamaModelModal from './OllamaModelModal'; import CodexOAuthModal from './CodexOAuthModal'; +import ParamOverrideEditorModal from './ParamOverrideEditorModal'; import JSONEditor from '../../../common/ui/JSONEditor'; import SecureVerificationModal from '../../../common/modals/SecureVerificationModal'; import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal'; @@ -75,6 +76,7 @@ import { IconServer, IconSetting, IconCode, + IconCopy, IconGlobe, IconBolt, IconSearch, @@ -99,6 +101,28 @@ const REGION_EXAMPLE = { 'claude-3-5-sonnet-20240620': 'europe-west1', }; +const PARAM_OVERRIDE_LEGACY_TEMPLATE = { + temperature: 0, +}; + +const PARAM_OVERRIDE_OPERATIONS_TEMPLATE = { + operations: [ + { + path: 'temperature', + mode: 'set', + value: 0.7, + conditions: [ + { + path: 'model', + mode: 'prefix', + value: 'openai/', + }, + ], + logic: 'AND', + }, + ], +}; + // 支持并且已适配通过接口获取模型列表的渠道类型 const MODEL_FETCHABLE_TYPES = new Set([ 1, 4, 14, 34, 17, 26, 27, 24, 47, 25, 20, 23, 31, 40, 42, 48, 43, @@ -148,6 +172,7 @@ const EditChannelModal = (props) => { base_url: '', other: '', model_mapping: '', + param_override: '', status_code_mapping: '', models: [], auto_ban: 1, @@ -251,11 +276,69 @@ const EditChannelModal = (props) => { name: keyword, }); }, [modelSearchMatchedCount, modelSearchValue, t]); + const paramOverrideMeta = useMemo(() => { + const raw = + typeof inputs.param_override === 'string' + ? inputs.param_override.trim() + : ''; + if (!raw) { + return { + tagLabel: t('不更改'), + tagColor: 'grey', + preview: t( + '此项可选,用于覆盖请求参数。不支持覆盖 stream 参数', + ), + }; + } + if (!verifyJSON(raw)) { + return { + tagLabel: t('JSON格式错误'), + tagColor: 'red', + preview: raw, + }; + } + try { + const parsed = JSON.parse(raw); + const pretty = JSON.stringify(parsed, null, 2); + if ( + parsed && + typeof parsed === 'object' && + !Array.isArray(parsed) && + Array.isArray(parsed.operations) + ) { + return { + tagLabel: `${t('新格式模板')} (${parsed.operations.length})`, + tagColor: 'cyan', + preview: pretty, + }; + } + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return { + tagLabel: `${t('旧格式模板')} (${Object.keys(parsed).length})`, + tagColor: 'blue', + preview: pretty, + }; + } + return { + tagLabel: t('自定义 JSON'), + tagColor: 'orange', + preview: pretty, + }; + } catch (error) { + return { + tagLabel: t('JSON格式错误'), + tagColor: 'red', + preview: raw, + }; + } + }, [inputs.param_override, t]); const [isIonetChannel, setIsIonetChannel] = useState(false); const [ionetMetadata, setIonetMetadata] = useState(null); const [codexOAuthModalVisible, setCodexOAuthModalVisible] = useState(false); const [codexCredentialRefreshing, setCodexCredentialRefreshing] = useState(false); + const [paramOverrideEditorVisible, setParamOverrideEditorVisible] = + useState(false); // 密钥显示状态 const [keyDisplayState, setKeyDisplayState] = useState({ @@ -582,6 +665,100 @@ const EditChannelModal = (props) => { } }; + const copyParamOverrideJson = async () => { + const raw = + typeof inputs.param_override === 'string' + ? inputs.param_override.trim() + : ''; + if (!raw) { + showInfo(t('暂无可复制 JSON')); + return; + } + + let content = raw; + if (verifyJSON(raw)) { + try { + content = JSON.stringify(JSON.parse(raw), null, 2); + } catch (error) { + content = raw; + } + } + + const ok = await copy(content); + if (ok) { + showSuccess(t('参数覆盖 JSON 已复制')); + } else { + showError(t('复制失败')); + } + }; + + const parseParamOverrideInput = () => { + const raw = + typeof inputs.param_override === 'string' + ? inputs.param_override.trim() + : ''; + if (!raw) return null; + if (!verifyJSON(raw)) { + throw new Error(t('当前参数覆盖不是合法的 JSON')); + } + return JSON.parse(raw); + }; + + const applyParamOverrideTemplate = ( + templateType = 'operations', + applyMode = 'fill', + ) => { + try { + const parsedCurrent = parseParamOverrideInput(); + if (templateType === 'legacy') { + if (applyMode === 'fill') { + handleInputChange( + 'param_override', + JSON.stringify(PARAM_OVERRIDE_LEGACY_TEMPLATE, null, 2), + ); + return; + } + const currentLegacy = + parsedCurrent && + typeof parsedCurrent === 'object' && + !Array.isArray(parsedCurrent) && + !Array.isArray(parsedCurrent.operations) + ? parsedCurrent + : {}; + const merged = { + ...PARAM_OVERRIDE_LEGACY_TEMPLATE, + ...currentLegacy, + }; + handleInputChange('param_override', JSON.stringify(merged, null, 2)); + return; + } + + if (applyMode === 'fill') { + handleInputChange( + 'param_override', + JSON.stringify(PARAM_OVERRIDE_OPERATIONS_TEMPLATE, null, 2), + ); + return; + } + const currentOperations = + parsedCurrent && + typeof parsedCurrent === 'object' && + !Array.isArray(parsedCurrent) && + Array.isArray(parsedCurrent.operations) + ? parsedCurrent.operations + : []; + const merged = { + operations: [ + ...currentOperations, + ...PARAM_OVERRIDE_OPERATIONS_TEMPLATE.operations, + ], + }; + handleInputChange('param_override', JSON.stringify(merged, null, 2)); + } catch (error) { + showError(error.message || t('模板应用失败')); + } + }; + const loadChannel = async () => { setLoading(true); let res = await API.get(`/api/channel/${channelId}`); @@ -1242,6 +1419,7 @@ const EditChannelModal = (props) => { const submit = async () => { const formValues = formApiRef.current ? formApiRef.current.getValues() : {}; let localInputs = { ...formValues }; + localInputs.param_override = inputs.param_override; if (localInputs.type === 57) { if (batch) { @@ -3150,78 +3328,73 @@ const EditChannelModal = (props) => { initValue={autoBan} /> -
+ {paramOverrideMeta.preview}
+
+
+ {stringifyPretty(raw)}
+
+
+ {paramTemplatePreviewMeta.preview}
+
+