diff --git a/controller/channel-test.go b/controller/channel-test.go index ab12132b..3947c8d5 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), } } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return testResult{ context: c, @@ -385,8 +385,15 @@ func testChannel(channel *model.Channel, testModel string, endpointType string, //} if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return testResult{ + context: c, + localErr: fixedErr, + newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr), + } + } return testResult{ context: c, localErr: err, diff --git a/controller/relay.go b/controller/relay.go index edea1586..1788b25b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { ModelName: relayInfo.OriginModelName, Retry: common.GetPointer(0), } + relayInfo.RetryIndex = 0 + relayInfo.LastError = nil for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + relayInfo.RetryIndex = retryParam.GetRetry() channel, channelErr := getChannel(c, relayInfo, retryParam) if channelErr != nil { logger.LogError(c, channelErr.Error()) @@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } if newAPIError == nil { + relayInfo.LastError = nil return } newAPIError = service.NormalizeViolationFeeError(newAPIError) + relayInfo.LastError = newAPIError processChannelError(c, *types.NewChannelError(channel.Id, 
channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) diff --git a/middleware/distributor.go b/middleware/distributor.go index 9e66cb8f..db57998c 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -348,8 +348,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) - common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) - common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride()) + paramOverride := channel.GetParamOverride() + headerOverride := channel.GetHeaderOverride() + if mergedParam, applied := service.ApplyChannelAffinityOverrideTemplate(c, paramOverride); applied { + paramOverride = mergedParam + } + common.SetContextKey(c, constant.ContextKeyChannelParamOverride, paramOverride) + common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, headerOverride) if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 09ca855d..79eac3ad 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -169,12 +169,17 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str // Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win. 
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) { headerOverride := make(map[string]string) + if info == nil { + return headerOverride, nil + } + + headerOverrideSource := common.GetEffectiveHeaderOverride(info) passAll := false var passthroughRegex []*regexp.Regexp if !info.IsChannelTest { - for k := range info.HeadersOverride { - key := strings.TrimSpace(k) + for k := range headerOverrideSource { + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } @@ -183,12 +188,11 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s continue } - lower := strings.ToLower(key) var pattern string switch { - case strings.HasPrefix(lower, headerPassthroughRegexPrefix): + case strings.HasPrefix(key, headerPassthroughRegexPrefix): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):]) - case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2): + case strings.HasPrefix(key, headerPassthroughRegexPrefixV2): pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):]) default: continue @@ -229,15 +233,15 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s if value == "" { continue } - headerOverride[name] = value + headerOverride[strings.ToLower(strings.TrimSpace(name))] = value } } - for k, v := range info.HeadersOverride { + for k, v := range headerOverrideSource { if isHeaderPassthroughRuleKey(k) { continue } - key := strings.TrimSpace(k) + key := strings.TrimSpace(strings.ToLower(k)) if key == "" { continue } diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index 6c7834ef..f697f855 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -53,7 +53,7 @@ func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testin headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - _, ok := headers["X-Upstream-Trace"] + _, ok := 
headers["x-upstream-trace"] require.False(t, ok) } @@ -77,7 +77,38 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Upstream-Trace"]) + require.Equal(t, "trace-123", headers["x-upstream-trace"]) +} + +func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]any{ + "x-static": "runtime-value", + "x-runtime": "runtime-only", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "X-Static": "legacy-value", + "X-Legacy": "legacy-only", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "runtime-value", headers["x-static"]) + require.Equal(t, "runtime-only", headers["x-runtime"]) + _, exists := headers["x-legacy"] + require.False(t, exists) } func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { @@ -101,8 +132,62 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { headers, err := processHeaderOverride(info, ctx) require.NoError(t, err) - require.Equal(t, "trace-123", headers["X-Trace-Id"]) + require.Equal(t, "trace-123", headers["x-trace-id"]) - _, hasAcceptEncoding := headers["Accept-Encoding"] + _, hasAcceptEncoding := headers["accept-encoding"] require.False(t, hasAcceptEncoding) } + +func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + ctx.Request.Header.Set("Originator", "Codex CLI") + ctx.Request.Header.Set("Session_id", "sess-123") + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + "Session_id": "sess-123", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ParamOverride: map[string]any{ + "operations": []any{ + map[string]any{ + "mode": "pass_headers", + "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"}, + }, + }, + }, + HeadersOverride: map[string]any{ + "X-Static": "legacy-value", + }, + }, + } + + _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info) + require.NoError(t, err) + require.True(t, info.UseRuntimeHeadersOverride) + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] + require.False(t, exists) + require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"]) + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "Codex CLI", headers["originator"]) + require.Equal(t, "sess-123", headers["session_id"]) + _, exists = headers["x-codex-beta-features"] + require.False(t, exists) + + upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) + applyHeaderOverrideToRequest(upstreamReq, headers) + require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator")) + require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id")) + require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features")) +} diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 6412b7d2..8f69b937 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -70,7 +70,6 @@ func 
applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ } func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) { - overrideCtx := relaycommon.BuildParamOverrideContext(info) chatJSON, err := common.Marshal(request) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) @@ -82,9 +81,9 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad } if len(info.ParamOverride) > 0 { - chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx) + chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info) if err != nil { - return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return nil, newAPIErrorFromParamOverride(err) } } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 9b08781c..1722cd9b 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -153,9 +153,9 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/common/override.go b/relay/common/override.go index 1a0c2478..59e15176 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -1,18 +1,29 @@ package common import ( + "errors" "fmt" + "net/http" "regexp" "strconv" "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/types" + 
"github.com/samber/lo" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`) +const ( + paramOverrideContextRequestHeaders = "request_headers" + paramOverrideContextHeaderOverride = "header_override" +) + +var errSourceHeaderNotFound = errors.New("source header does not exist") + type ConditionOperation struct { Path string `json:"path"` // JSON路径 Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte @@ -23,7 +34,7 @@ type ConditionOperation struct { type ParamOperation struct { Path string `json:"path"` - Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace + Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects, set_header, delete_header, copy_header, move_header, pass_headers, sync_fields Value interface{} `json:"value"` KeepOrigin bool `json:"keep_origin"` From string `json:"from,omitempty"` @@ -32,6 +43,76 @@ type ParamOperation struct { Logic string `json:"logic,omitempty"` // AND, OR (默认OR) } +type ParamOverrideReturnError struct { + Message string + StatusCode int + Code string + Type string + SkipRetry bool +} + +func (e *ParamOverrideReturnError) Error() string { + if e == nil { + return "param override return error" + } + if e.Message == "" { + return "param override return error" + } + return e.Message +} + +func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) { + if err == nil { + return nil, false + } + var target *ParamOverrideReturnError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError { + if err == nil { + return types.NewError( + 
errors.New("param override return error is nil"), + types.ErrorCodeChannelParamOverrideInvalid, + types.ErrOptionWithSkipRetry(), + ) + } + + statusCode := err.StatusCode + if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired { + statusCode = http.StatusBadRequest + } + + errorCode := err.Code + if strings.TrimSpace(errorCode) == "" { + errorCode = string(types.ErrorCodeInvalidRequest) + } + + errorType := err.Type + if strings.TrimSpace(errorType) == "" { + errorType = "invalid_request_error" + } + + message := strings.TrimSpace(err.Message) + if message == "" { + message = "request blocked by param override" + } + + opts := make([]types.NewAPIErrorOptions, 0, 1) + if err.SkipRetry { + opts = append(opts, types.ErrOptionWithSkipRetry()) + } + + return types.WithOpenAIError(types.OpenAIError{ + Message: message, + Type: errorType, + Code: errorCode, + }, statusCode, opts...) +} + func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) { if len(paramOverride) == 0 { return jsonData, nil @@ -48,81 +129,147 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c return applyOperationsLegacy(jsonData, paramOverride) } -func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { - // 检查是否包含 "operations" 字段 - if opsValue, exists := paramOverride["operations"]; exists { - if opsSlice, ok := opsValue.([]interface{}); ok { - var operations []ParamOperation - for _, op := range opsSlice { - if opMap, ok := op.(map[string]interface{}); ok { - operation := ParamOperation{} - - // 断言必要字段 - if path, ok := opMap["path"].(string); ok { - operation.Path = path - } - if mode, ok := opMap["mode"].(string); ok { - operation.Mode = mode - } else { - return nil, false // mode 是必需的 - } - - // 可选字段 - if value, exists := opMap["value"]; exists { - operation.Value = value - } - if keepOrigin, ok := 
opMap["keep_origin"].(bool); ok { - operation.KeepOrigin = keepOrigin - } - if from, ok := opMap["from"].(string); ok { - operation.From = from - } - if to, ok := opMap["to"].(string); ok { - operation.To = to - } - if logic, ok := opMap["logic"].(string); ok { - operation.Logic = logic - } else { - operation.Logic = "OR" // 默认为OR - } - - // 解析条件 - if conditions, exists := opMap["conditions"]; exists { - if condSlice, ok := conditions.([]interface{}); ok { - for _, cond := range condSlice { - if condMap, ok := cond.(map[string]interface{}); ok { - condition := ConditionOperation{} - if path, ok := condMap["path"].(string); ok { - condition.Path = path - } - if mode, ok := condMap["mode"].(string); ok { - condition.Mode = mode - } - if value, ok := condMap["value"]; ok { - condition.Value = value - } - if invert, ok := condMap["invert"].(bool); ok { - condition.Invert = invert - } - if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok { - condition.PassMissingKey = passMissingKey - } - operation.Conditions = append(operation.Conditions, condition) - } - } - } - } - - operations = append(operations, operation) - } else { - return nil, false - } - } - return operations, true - } +func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) { + paramOverride := getParamOverrideMap(info) + if len(paramOverride) == 0 { + return jsonData, nil } - return nil, false + overrideCtx := BuildParamOverrideContext(info) + result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx) + if err != nil { + return nil, err + } + syncRuntimeHeaderOverrideFromContext(info, overrideCtx) + return result, nil +} + +func getParamOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return info.ChannelMeta.ParamOverride +} + +func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} { + if info == nil || info.ChannelMeta == nil { + return nil + } + return 
info.ChannelMeta.HeadersOverride +} + +func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} { + if len(source) == 0 { + return map[string]interface{}{} + } + target := make(map[string]interface{}, len(source)) + for key, value := range source { + normalizedKey := normalizeHeaderContextKey(key) + if normalizedKey == "" { + continue + } + normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if normalizedValue == "" { + if isHeaderPassthroughRuleKeyForOverride(normalizedKey) { + target[normalizedKey] = "" + } + continue + } + target[normalizedKey] = normalizedValue + } + return target +} + +func isHeaderPassthroughRuleKeyForOverride(key string) bool { + key = strings.TrimSpace(strings.ToLower(key)) + if key == "" { + return false + } + if key == "*" { + return true + } + return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:") +} + +func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} { + if info == nil { + return map[string]interface{}{} + } + if info.UseRuntimeHeadersOverride { + return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride) + } + return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info)) +} + +func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { + // 检查是否包含 "operations" 字段 + opsValue, exists := paramOverride["operations"] + if !exists { + return nil, false + } + + var opMaps []map[string]interface{} + switch ops := opsValue.(type) { + case []interface{}: + opMaps = make([]map[string]interface{}, 0, len(ops)) + for _, op := range ops { + opMap, ok := op.(map[string]interface{}) + if !ok { + return nil, false + } + opMaps = append(opMaps, opMap) + } + case []map[string]interface{}: + opMaps = ops + default: + return nil, false + } + + operations := make([]ParamOperation, 0, len(opMaps)) + for _, opMap := range opMaps { + operation := ParamOperation{} + + // 断言必要字段 + if path, ok := opMap["path"].(string); ok { + operation.Path = path + } 
+ if mode, ok := opMap["mode"].(string); ok { + operation.Mode = mode + } else { + return nil, false // mode 是必需的 + } + + // 可选字段 + if value, exists := opMap["value"]; exists { + operation.Value = value + } + if keepOrigin, ok := opMap["keep_origin"].(bool); ok { + operation.KeepOrigin = keepOrigin + } + if from, ok := opMap["from"].(string); ok { + operation.From = from + } + if to, ok := opMap["to"].(string); ok { + operation.To = to + } + if logic, ok := opMap["logic"].(string); ok { + operation.Logic = logic + } else { + operation.Logic = "OR" // 默认为OR + } + + // 解析条件 + if conditions, exists := opMap["conditions"]; exists { + parsedConditions, err := parseConditionOperations(conditions) + if err != nil { + return nil, false + } + operation.Conditions = append(operation.Conditions, parsedConditions...) + } + + operations = append(operations, operation) + } + return operations, true } func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { @@ -139,20 +286,9 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio } if strings.ToUpper(logic) == "AND" { - for _, result := range results { - if !result { - return false, nil - } - } - return true, nil - } else { - for _, result := range results { - if result { - return true, nil - } - } - return false, nil + return lo.EveryBy(results, func(item bool) bool { return item }), nil } + return lo.SomeBy(results, func(item bool) bool { return item }), nil } func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) { @@ -309,13 +445,10 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{} } func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { - var contextJSON string - if conditionContext != nil && len(conditionContext) > 0 { - ctxBytes, err := common.Marshal(conditionContext) - if err != nil { - return 
"", fmt.Errorf("failed to marshal condition context: %v", err) - } - contextJSON = string(ctxBytes) + context := ensureContextMap(conditionContext) + contextJSON, err := marshalContextJSON(context) + if err != nil { + return "", fmt.Errorf("failed to marshal condition context: %v", err) } result := jsonStr @@ -372,16 +505,631 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte result, err = replaceStringValue(result, opPath, op.From, op.To) case "regex_replace": result, err = regexReplaceStringValue(result, opPath, op.From, op.To) + case "return_error": + returnErr, parseErr := parseParamOverrideReturnError(op.Value) + if parseErr != nil { + return "", parseErr + } + return "", returnErr + case "prune_objects": + result, err = pruneObjects(result, opPath, contextJSON, op.Value) + case "set_header": + err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "delete_header": + err = deleteHeaderOverrideInContext(context, op.Path) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "copy_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "move_header": + sourceHeader := strings.TrimSpace(op.From) + targetHeader := strings.TrimSpace(op.To) + if sourceHeader == "" { + sourceHeader = strings.TrimSpace(op.Path) + } + if targetHeader == "" { + targetHeader = strings.TrimSpace(op.Path) + } + err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin) + if errors.Is(err, errSourceHeaderNotFound) { 
+ err = nil + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "pass_headers": + headerNames, parseErr := parseHeaderPassThroughNames(op.Value) + if parseErr != nil { + return "", parseErr + } + for _, headerName := range headerNames { + if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil { + if errors.Is(err, errSourceHeaderNotFound) { + err = nil + continue + } + break + } + } + if err == nil { + contextJSON, err = marshalContextJSON(context) + } + case "sync_fields": + result, err = syncFieldsBetweenTargets(result, context, op.From, op.To) + if err == nil { + contextJSON, err = marshalContextJSON(context) + } default: return "", fmt.Errorf("unknown operation: %s", op.Mode) } if err != nil { - return "", fmt.Errorf("operation %s failed: %v", op.Mode, err) + return "", fmt.Errorf("operation %s failed: %w", op.Mode, err) } } return result, nil } +func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) { + result := &ParamOverrideReturnError{ + StatusCode: http.StatusBadRequest, + Code: string(types.ErrorCodeInvalidRequest), + Type: "invalid_request_error", + SkipRetry: true, + } + + switch raw := value.(type) { + case nil: + return nil, fmt.Errorf("return_error value is required") + case string: + result.Message = strings.TrimSpace(raw) + case map[string]interface{}: + if message, ok := raw["message"].(string); ok { + result.Message = strings.TrimSpace(message) + } + if result.Message == "" { + if message, ok := raw["msg"].(string); ok { + result.Message = strings.TrimSpace(message) + } + } + + if code, exists := raw["code"]; exists { + codeStr := strings.TrimSpace(fmt.Sprintf("%v", code)) + if codeStr != "" { + result.Code = codeStr + } + } + if errType, ok := raw["type"].(string); ok { + errType = strings.TrimSpace(errType) + if errType != "" { + result.Type = errType + } + } + if skipRetry, ok := raw["skip_retry"].(bool); ok { + result.SkipRetry = skipRetry + } + 
+ if statusCodeRaw, exists := raw["status_code"]; exists { + statusCode, ok := parseOverrideInt(statusCodeRaw) + if !ok { + return nil, fmt.Errorf("return_error status_code must be an integer") + } + result.StatusCode = statusCode + } else if statusRaw, exists := raw["status"]; exists { + statusCode, ok := parseOverrideInt(statusRaw) + if !ok { + return nil, fmt.Errorf("return_error status must be an integer") + } + result.StatusCode = statusCode + } + default: + return nil, fmt.Errorf("return_error value must be string or object") + } + + if result.Message == "" { + return nil, fmt.Errorf("return_error message is required") + } + if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired { + return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode) + } + + return result, nil +} + +func parseOverrideInt(v interface{}) (int, bool) { + switch value := v.(type) { + case int: + return value, true + case float64: + if value != float64(int(value)) { + return 0, false + } + return int(value), true + default: + return 0, false + } +} + +func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} { + if conditionContext != nil { + return conditionContext + } + return make(map[string]interface{}) +} + +func marshalContextJSON(context map[string]interface{}) (string, error) { + if context == nil || len(context) == 0 { + return "", nil + } + ctxBytes, err := common.Marshal(context) + if err != nil { + return "", err + } + return string(ctxBytes), nil +} + +func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + if keepOrigin { + if existing, ok := rawHeaders[headerName]; ok { + existingValue := 
strings.TrimSpace(fmt.Sprintf("%v", existing)) + if existingValue != "" { + return nil + } + } + } + + headerValue, hasValue, err := resolveHeaderOverrideValue(context, headerName, value) + if err != nil { + return err + } + if !hasValue { + delete(rawHeaders, headerName) + return nil + } + + rawHeaders[headerName] = headerValue + return nil +} + +func resolveHeaderOverrideValue(context map[string]interface{}, headerName string, value interface{}) (string, bool, error) { + if value == nil { + return "", false, fmt.Errorf("header value is required") + } + + if mapping, ok := value.(map[string]interface{}); ok { + return resolveHeaderOverrideValueByMapping(context, headerName, mapping) + } + if mapping, ok := value.(map[string]string); ok { + converted := make(map[string]interface{}, len(mapping)) + for key, item := range mapping { + converted[key] = item + } + return resolveHeaderOverrideValueByMapping(context, headerName, converted) + } + + headerValue := strings.TrimSpace(fmt.Sprintf("%v", value)) + if headerValue == "" { + return "", false, nil + } + return headerValue, true, nil +} + +func resolveHeaderOverrideValueByMapping(context map[string]interface{}, headerName string, mapping map[string]interface{}) (string, bool, error) { + if len(mapping) == 0 { + return "", false, fmt.Errorf("header value mapping cannot be empty") + } + + sourceValue, exists := getHeaderValueFromContext(context, headerName) + if !exists { + return "", false, nil + } + sourceTokens := splitHeaderListValue(sourceValue) + if len(sourceTokens) == 0 { + return "", false, nil + } + + wildcardValue, hasWildcard := mapping["*"] + resultTokens := make([]string, 0, len(sourceTokens)) + for _, token := range sourceTokens { + replacementRaw, hasReplacement := mapping[token] + if !hasReplacement && hasWildcard { + replacementRaw = wildcardValue + hasReplacement = true + } + if !hasReplacement { + resultTokens = append(resultTokens, token) + continue + } + replacementTokens, err := 
parseHeaderReplacementTokens(replacementRaw) + if err != nil { + return "", false, err + } + resultTokens = append(resultTokens, replacementTokens...) + } + + resultTokens = lo.Uniq(resultTokens) + if len(resultTokens) == 0 { + return "", false, nil + } + return strings.Join(resultTokens, ","), true, nil +} + +func parseHeaderReplacementTokens(value interface{}) ([]string, error) { + switch raw := value.(type) { + case nil: + return nil, nil + case string: + return splitHeaderListValue(raw), nil + case []string: + tokens := make([]string, 0, len(raw)) + for _, item := range raw { + tokens = append(tokens, splitHeaderListValue(item)...) + } + return lo.Uniq(tokens), nil + case []interface{}: + tokens := make([]string, 0, len(raw)) + for _, item := range raw { + itemTokens, err := parseHeaderReplacementTokens(item) + if err != nil { + return nil, err + } + tokens = append(tokens, itemTokens...) + } + return lo.Uniq(tokens), nil + case map[string]interface{}, map[string]string: + return nil, fmt.Errorf("header replacement value must be string, array or null") + default: + token := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if token == "" { + return nil, nil + } + return []string{token}, nil + } +} + +func splitHeaderListValue(raw string) []string { + items := strings.Split(raw, ",") + return lo.FilterMap(items, func(item string, _ int) (string, bool) { + token := strings.TrimSpace(item) + if token == "" { + return "", false + } + return token, true + }) +} + +func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("copy_header from/to is required") + } + value, exists := getHeaderValueFromContext(context, fromHeader) + if !exists { + return fmt.Errorf("%w: %s", errSourceHeaderNotFound, fromHeader) + } + return setHeaderOverrideInContext(context, 
toHeader, value, keepOrigin) +} + +func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error { + fromHeader = normalizeHeaderContextKey(fromHeader) + toHeader = normalizeHeaderContextKey(toHeader) + if fromHeader == "" || toHeader == "" { + return fmt.Errorf("move_header from/to is required") + } + if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil { + return err + } + if strings.EqualFold(fromHeader, toHeader) { + return nil + } + return deleteHeaderOverrideInContext(context, fromHeader) +} + +func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return fmt.Errorf("header name is required") + } + rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride) + delete(rawHeaders, headerName) + return nil +} + +func parseHeaderPassThroughNames(value interface{}) ([]string, error) { + normalizeNames := func(values []string) []string { + names := lo.FilterMap(values, func(item string, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(item) + if headerName == "" { + return "", false + } + return headerName, true + }) + return lo.Uniq(names) + } + + switch raw := value.(type) { + case nil: + return nil, fmt.Errorf("pass_headers value is required") + case string: + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil, fmt.Errorf("pass_headers value is required") + } + if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") { + var parsed interface{} + if err := common.UnmarshalJsonStr(trimmed, &parsed); err == nil { + return parseHeaderPassThroughNames(parsed) + } + } + names := normalizeNames(strings.Split(trimmed, ",")) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case []interface{}: + names := lo.FilterMap(raw, func(item interface{}, _ int) 
(string, bool) { + headerName := normalizeHeaderContextKey(fmt.Sprintf("%v", item)) + if headerName == "" { + return "", false + } + return headerName, true + }) + names = lo.Uniq(names) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case []string: + names := lo.FilterMap(raw, func(item string, _ int) (string, bool) { + headerName := normalizeHeaderContextKey(item) + if headerName == "" { + return "", false + } + return headerName, true + }) + names = lo.Uniq(names) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + case map[string]interface{}: + candidates := make([]string, 0, 8) + if headersRaw, ok := raw["headers"]; ok { + names, err := parseHeaderPassThroughNames(headersRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + if namesRaw, ok := raw["names"]; ok { + names, err := parseHeaderPassThroughNames(namesRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + if headerRaw, ok := raw["header"]; ok { + names, err := parseHeaderPassThroughNames(headerRaw) + if err == nil { + candidates = append(candidates, names...) + } + } + names := normalizeNames(candidates) + if len(names) == 0 { + return nil, fmt.Errorf("pass_headers value is invalid") + } + return names, nil + default: + return nil, fmt.Errorf("pass_headers value must be string, array or object") + } +} + +type syncTarget struct { + kind string + key string +} + +func parseSyncTarget(spec string) (syncTarget, error) { + raw := strings.TrimSpace(spec) + if raw == "" { + return syncTarget{}, fmt.Errorf("sync_fields target is required") + } + + idx := strings.Index(raw, ":") + if idx < 0 { + // Backward compatibility: treat bare value as JSON path. 
+ return syncTarget{ + kind: "json", + key: raw, + }, nil + } + + kind := strings.ToLower(strings.TrimSpace(raw[:idx])) + key := strings.TrimSpace(raw[idx+1:]) + if key == "" { + return syncTarget{}, fmt.Errorf("sync_fields target key is required: %s", raw) + } + + switch kind { + case "json", "body": + return syncTarget{ + kind: "json", + key: key, + }, nil + case "header": + return syncTarget{ + kind: "header", + key: key, + }, nil + default: + return syncTarget{}, fmt.Errorf("sync_fields target prefix is invalid: %s", raw) + } +} + +func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) { + switch target.kind { + case "json": + path := processNegativeIndex(jsonStr, target.key) + value := gjson.Get(jsonStr, path) + if !value.Exists() || value.Type == gjson.Null { + return nil, false, nil + } + if value.Type == gjson.String && strings.TrimSpace(value.String()) == "" { + return nil, false, nil + } + return value.Value(), true, nil + case "header": + value, ok := getHeaderValueFromContext(context, target.key) + if !ok || strings.TrimSpace(value) == "" { + return nil, false, nil + } + return value, true, nil + default: + return nil, false, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) + } +} + +func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) { + switch target.kind { + case "json": + path := processNegativeIndex(jsonStr, target.key) + nextJSON, err := sjson.Set(jsonStr, path, value) + if err != nil { + return "", err + } + return nextJSON, nil + case "header": + if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil { + return "", err + } + return jsonStr, nil + default: + return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) + } +} + +func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, 
error) { + fromTarget, err := parseSyncTarget(fromSpec) + if err != nil { + return "", err + } + toTarget, err := parseSyncTarget(toSpec) + if err != nil { + return "", err + } + + fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget) + if err != nil { + return "", err + } + toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget) + if err != nil { + return "", err + } + + // If one side exists and the other side is missing, sync the missing side. + if fromExists && !toExists { + return writeSyncTargetValue(jsonStr, context, toTarget, fromValue) + } + if toExists && !fromExists { + return writeSyncTargetValue(jsonStr, context, fromTarget, toValue) + } + return jsonStr, nil +} + +func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} { + if context == nil { + return map[string]interface{}{} + } + if existing, ok := context[key]; ok { + if mapVal, ok := existing.(map[string]interface{}); ok { + return mapVal + } + } + result := make(map[string]interface{}) + context[key] = result + return result +} + +func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) { + headerName = normalizeHeaderContextKey(headerName) + if headerName == "" { + return "", false + } + for _, key := range []string{paramOverrideContextHeaderOverride, paramOverrideContextRequestHeaders} { + source := ensureMapKeyInContext(context, key) + raw, ok := source[headerName] + if !ok { + continue + } + value := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if value != "" { + return value, true + } + } + return "", false +} + +func normalizeHeaderContextKey(key string) string { + return strings.TrimSpace(strings.ToLower(key)) +} + +func buildRequestHeadersContext(headers map[string]string) map[string]interface{} { + if len(headers) == 0 { + return map[string]interface{}{} + } + entries := lo.Entries(headers) + normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, 
string], _ int) (lo.Entry[string, string], bool) { + normalized := normalizeHeaderContextKey(item.Key) + value := strings.TrimSpace(item.Value) + if normalized == "" || value == "" { + return lo.Entry[string, string]{}, false + } + return lo.Entry[string, string]{Key: normalized, Value: value}, true + }) + return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) { + return item.Key, item.Value + }) +} + +func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) { + if info == nil || context == nil { + return + } + raw, exists := context[paramOverrideContextHeaderOverride] + if !exists { + return + } + rawMap, ok := raw.(map[string]interface{}) + if !ok { + return + } + info.RuntimeHeadersOverride = sanitizeHeaderOverrideMap(rawMap) + info.UseRuntimeHeadersOverride = true +} + func moveValue(jsonStr, fromPath, toPath string) (string, error) { sourceValue := gjson.Get(jsonStr, fromPath) if !sourceValue.Exists() { @@ -537,6 +1285,235 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement)) } +type pruneObjectsOptions struct { + conditions []ConditionOperation + logic string + recursive bool +} + +func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) { + options, err := parsePruneObjectsOptions(value) + if err != nil { + return "", err + } + + if path == "" { + var root interface{} + if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { + return "", err + } + cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return string(cleanedBytes), nil + } + + target := gjson.Get(jsonStr, path) + if !target.Exists() { + return jsonStr, nil + } + + var targetNode interface{} + if target.Type == gjson.JSON { + if 
err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil { + return "", err + } + } else { + targetNode = target.Value() + } + + cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true) + if err != nil { + return "", err + } + cleanedBytes, err := common.Marshal(cleaned) + if err != nil { + return "", err + } + return sjson.SetRaw(jsonStr, path, string(cleanedBytes)) +} + +func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { + opts := pruneObjectsOptions{ + logic: "AND", + recursive: true, + } + + switch raw := value.(type) { + case nil: + return opts, fmt.Errorf("prune_objects value is required") + case string: + v := strings.TrimSpace(raw) + if v == "" { + return opts, fmt.Errorf("prune_objects value is required") + } + opts.conditions = []ConditionOperation{ + { + Path: "type", + Mode: "full", + Value: v, + }, + } + case map[string]interface{}: + if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" { + opts.logic = logic + } + if recursive, ok := raw["recursive"].(bool); ok { + opts.recursive = recursive + } + + if condRaw, exists := raw["conditions"]; exists { + conditions, err := parseConditionOperations(condRaw) + if err != nil { + return opts, err + } + opts.conditions = append(opts.conditions, conditions...) 
+ } + + if whereRaw, exists := raw["where"]; exists { + whereMap, ok := whereRaw.(map[string]interface{}) + if !ok { + return opts, fmt.Errorf("prune_objects where must be object") + } + for key, val := range whereMap { + key = strings.TrimSpace(key) + if key == "" { + continue + } + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: key, + Mode: "full", + Value: val, + }) + } + } + + if matchType, exists := raw["type"]; exists { + opts.conditions = append(opts.conditions, ConditionOperation{ + Path: "type", + Mode: "full", + Value: matchType, + }) + } + default: + return opts, fmt.Errorf("prune_objects value must be string or object") + } + + if len(opts.conditions) == 0 { + return opts, fmt.Errorf("prune_objects conditions are required") + } + return opts, nil +} + +func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) { + switch typed := raw.(type) { + case map[string]interface{}: + entries := lo.Entries(typed) + conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) { + path := strings.TrimSpace(item.Key) + if path == "" { + return ConditionOperation{}, false + } + return ConditionOperation{ + Path: path, + Mode: "full", + Value: item.Value, + }, true + }) + if len(conditions) == 0 { + return nil, fmt.Errorf("conditions object must contain at least one key") + } + return conditions, nil + case []interface{}: + items := typed + result := make([]ConditionOperation, 0, len(items)) + for _, item := range items { + itemMap, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("condition must be object") + } + path, _ := itemMap["path"].(string) + mode, _ := itemMap["mode"].(string) + if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" { + return nil, fmt.Errorf("condition path/mode is required") + } + condition := ConditionOperation{ + Path: path, + Mode: mode, + } + if value, exists := itemMap["value"]; exists { + condition.Value = 
value + } + if invert, ok := itemMap["invert"].(bool); ok { + condition.Invert = invert + } + if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok { + condition.PassMissingKey = passMissingKey + } + result = append(result, condition) + } + return result, nil + default: + return nil, fmt.Errorf("conditions must be an array or object") + } +} + +func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) { + switch value := node.(type) { + case []interface{}: + result := make([]interface{}, 0, len(value)) + for _, item := range value { + next, drop, err := pruneObjectsNode(item, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + continue + } + result = append(result, next) + } + return result, false, nil + case map[string]interface{}: + shouldDrop, err := shouldPruneObject(value, options, contextJSON) + if err != nil { + return nil, false, err + } + if shouldDrop && !isRoot { + return nil, true, nil + } + if !options.recursive { + return value, false, nil + } + for key, child := range value { + next, drop, err := pruneObjectsNode(child, options, contextJSON, false) + if err != nil { + return nil, false, err + } + if drop { + delete(value, key) + continue + } + value[key] = next + } + return value, false, nil + default: + return node, false, nil + } +} + +func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) { + nodeBytes, err := common.Marshal(node) + if err != nil { + return false, err + } + return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic) +} + func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { current := gjson.Get(jsonStr, path) var currentMap, newMap map[string]interface{} @@ -598,6 +1575,37 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} { } } + 
ctx[paramOverrideContextRequestHeaders] = buildRequestHeadersContext(info.RequestHeaders) + + headerOverrideSource := GetEffectiveHeaderOverride(info) + ctx[paramOverrideContextHeaderOverride] = sanitizeHeaderOverrideMap(headerOverrideSource) + + ctx["retry_index"] = info.RetryIndex + ctx["is_retry"] = info.RetryIndex > 0 + ctx["retry"] = map[string]interface{}{ + "index": info.RetryIndex, + "is_retry": info.RetryIndex > 0, + } + + if info.LastError != nil { + code := string(info.LastError.GetErrorCode()) + errorType := string(info.LastError.GetErrorType()) + lastError := map[string]interface{}{ + "status_code": info.LastError.StatusCode, + "message": info.LastError.Error(), + "code": code, + "error_code": code, + "type": errorType, + "error_type": errorType, + "skip_retry": types.IsSkipRetryError(info.LastError), + } + ctx["last_error"] = lastError + ctx["last_error_status_code"] = info.LastError.StatusCode + ctx["last_error_message"] = info.LastError.Error() + ctx["last_error_code"] = code + ctx["last_error_type"] = errorType + } + ctx["is_channel_test"] = info.IsChannelTest return ctx } diff --git a/relay/common/override_test.go b/relay/common/override_test.go index c83cddff..5f49d95a 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -5,6 +5,8 @@ import ( "reflect" "testing" + "github.com/QuantumNous/new-api/types" + "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/setting/model_setting" ) @@ -775,6 +777,754 @@ func TestApplyParamOverrideToUpper(t *testing.T) { assertJSONEqual(t, `{"model":"GPT-4"}`, string(out)) } +func TestApplyParamOverrideReturnError(t *testing.T) { + input := []byte(`{"model":"gemini-2.5-pro"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "return_error", + "value": map[string]interface{}{ + "message": "forced bad request by param override", + "status_code": 422, + "code": "forced_bad_request", + "type": "invalid_request_error", + 
"skip_retry": true, + }, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "retry.is_retry", + "mode": "full", + "value": true, + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "retry": map[string]interface{}{ + "index": 1, + "is_retry": true, + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err == nil { + t.Fatalf("expected error, got nil") + } + returnErr, ok := AsParamOverrideReturnError(err) + if !ok { + t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err) + } + if returnErr.StatusCode != 422 { + t.Fatalf("expected status 422, got %d", returnErr.StatusCode) + } + if returnErr.Code != "forced_bad_request" { + t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code) + } + if !returnErr.SkipRetry { + t.Fatalf("expected skip_retry true") + } +} + +func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) { + input := []byte(`{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"redacted_thinking","text":"secret"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"redacted_thinking","text":"secret2"}, + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "prune_objects", + "value": "redacted_thinking", + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "messages":[ + {"role":"assistant","content":[ + {"type":"output_text","text":"a"}, + {"type":"tool_call","name":"tool_a"} + ]}, + {"role":"assistant","content":[ + {"type":"output_text","text":"b"}, + {"type":"wrapper","parts":[ + {"type":"output_text","text":"c"} + ]} + ]} + ] + }`, string(out)) +} + +func TestApplyParamOverridePruneObjectsWhereAndPath(t 
*testing.T) { + input := []byte(`{ + "a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "a", + "mode": "prune_objects", + "value": map[string]interface{}{ + "where": map[string]interface{}{ + "type": "redacted_thinking", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{ + "a":{"items":[{"type":"output_text","id":2}]}, + "b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]} + }`, string(out)) +} + +func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) { + input := []byte(`{"items":[{"type":"redacted_thinking"}]}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "normalize_thinking_signature", + }, + }, + } + + _, err := ApplyParamOverride(input, override, nil) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) { + info := &RelayInfo{ + RetryIndex: 1, + LastError: types.WithOpenAIError(types.OpenAIError{ + Message: "invalid thinking signature", + Type: "invalid_request_error", + Code: "bad_thought_signature", + }, 400), + } + ctx := BuildParamOverrideContext(info) + + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": []interface{}{ + map[string]interface{}{ + "path": "is_retry", + "mode": "full", + "value": true, + }, + map[string]interface{}{ + "path": "last_error.code", + "mode": "contains", + "value": "thought_signature", + }, + }, + }, 
+ }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "request_headers.authorization", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Debug-Mode", + "value": "enabled", + }, + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override.x-debug-mode", + "mode": "full", + "value": "enabled", + }, + }, + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "Authorization", + "to": "X-Upstream-Auth", + }, + 
map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "conditions": []interface{}{ + map[string]interface{}{ + "path": "header_override.x-upstream-auth", + "mode": "contains", + "value": "Bearer ", + }, + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "pass_headers", + "value": []interface{}{"X-Codex-Beta-Features", "Session_id"}, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "sess-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["session_id"] != "sess-123" { + t.Fatalf("expected session_id to be passed, got: %v", headers["session_id"]) + } + if _, exists := headers["x-codex-beta-features"]; exists { + t.Fatalf("expected missing header to be skipped") + } +} + +func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer 
token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["x-upstream-auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + +func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Missing-Header", + "to": "X-Upstream-Auth", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "authorization": "Bearer token-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + return + } + if _, exists := headers["x-upstream-auth"]; exists { + t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing") + } +} + +func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) { + input := []byte(`{"model":"gpt-4"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "sess-123", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"sess-123"}`, string(out)) +} + +func 
TestApplyParamOverrideSyncFieldsJSONToHeader(t *testing.T) { + input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-abc"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{} + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-abc"}`, string(out)) + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["session_id"] != "cache-abc" { + t.Fatalf("expected session_id to be synced from prompt_cache_key, got: %v", headers["session_id"]) + } +} + +func TestApplyParamOverrideSyncFieldsNoChangeWhenBothExist(t *testing.T) { + input := []byte(`{"model":"gpt-4","prompt_cache_key":"cache-body"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "sync_fields", + "from": "header:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "session_id": "cache-header", + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","prompt_cache_key":"cache-body"}`, string(out)) + + headers, _ := ctx["header_override"].(map[string]interface{}) + if headers != nil { + if _, exists := headers["session_id"]; exists { + t.Fatalf("expected no override when both sides already have value") + } + } +} + +func TestApplyParamOverrideSyncFieldsInvalidTarget(t *testing.T) { + input := []byte(`{"model":"gpt-4"}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": 
"sync_fields", + "from": "foo:session_id", + "to": "json:prompt_cache_key", + }, + }, + } + + _, err := ApplyParamOverride(input, override, nil) + if err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "X-Feature-Flag", + "value": "new-value", + "keep_origin": true, + }, + }, + } + ctx := map[string]interface{}{ + "header_override": map[string]interface{}{ + "x-feature-flag": "legacy-value", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["x-feature-flag"] != "legacy-value" { + t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["x-feature-flag"]) + } +} + +func TestApplyParamOverrideSetHeaderMapRewritesCommaSeparatedHeader(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": "computer-use-2025-01-24", + }, + }, + }, + } + ctx := map[string]interface{}{ + "request_headers": map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if headers["anthropic-beta"] != "computer-use-2025-01-24" { + t.Fatalf("expected 
anthropic-beta to keep only mapped value, got: %v", headers["anthropic-beta"]) + } +} + +func TestApplyParamOverrideSetHeaderMapDeleteWholeHeaderWhenAllTokensCleared(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": nil, + }, + }, + }, + } + ctx := map[string]interface{}{ + "header_override": map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20,computer-use-2025-01-24", + }, + } + + _, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + + headers, ok := ctx["header_override"].(map[string]interface{}) + if !ok { + t.Fatalf("expected header_override context map") + } + if _, exists := headers["anthropic-beta"]; exists { + t.Fatalf("expected anthropic-beta to be deleted when all mapped values are null") + } +} + +func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) { + input := []byte(`{"temperature":0.7}`) + override := map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "path": "temperature", + "mode": "set", + "value": 0.1, + "logic": "AND", + "conditions": map[string]interface{}{ + "is_retry": true, + "last_error.status_code": 400.0, + }, + }, + }, + } + ctx := map[string]interface{}{ + "is_retry": true, + "last_error": map[string]interface{}{ + "status_code": 400.0, + }, + } + + out, err := ApplyParamOverride(input, override, ctx) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.1}`, string(out)) +} + +func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + 
map[string]interface{}{ + "mode": "set_header", + "path": "X-Injected-By-Param-Override", + "value": "enabled", + }, + map[string]interface{}{ + "mode": "delete_header", + "path": "X-Delete-Me", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Delete-Me": "legacy", + "X-Keep-Me": "keep", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + out, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + assertJSONEqual(t, `{"temperature":0.7}`, string(out)) + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["x-keep-me"] != "keep" { + t.Fatalf("expected x-keep-me header to be preserved, got: %v", info.RuntimeHeadersOverride["x-keep-me"]) + } + if info.RuntimeHeadersOverride["x-injected-by-param-override"] != "enabled" { + t.Fatalf("expected x-injected-by-param-override header to be set, got: %v", info.RuntimeHeadersOverride["x-injected-by-param-override"]) + } + if _, exists := info.RuntimeHeadersOverride["x-delete-me"]; exists { + t.Fatalf("expected x-delete-me header to be deleted") + } +} + +func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "move_header", + "from": "X-Legacy-Trace", + "to": "X-Trace", + }, + map[string]interface{}{ + "mode": "copy_header", + "from": "X-Trace", + "to": "X-Trace-Backup", + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Legacy-Trace": "trace-123", + }, + }, + } + + input := []byte(`{"temperature":0.7}`) + _, err := ApplyParamOverrideWithRelayInfo(input, info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + if _, exists := info.RuntimeHeadersOverride["x-legacy-trace"]; exists { + t.Fatalf("expected 
source header to be removed after move") + } + if info.RuntimeHeadersOverride["x-trace"] != "trace-123" { + t.Fatalf("expected x-trace to be set, got: %v", info.RuntimeHeadersOverride["x-trace"]) + } + if info.RuntimeHeadersOverride["x-trace-backup"] != "trace-123" { + t.Fatalf("expected x-trace-backup to be copied, got: %v", info.RuntimeHeadersOverride["x-trace-backup"]) + } +} + +func TestApplyParamOverrideWithRelayInfoSetHeaderMapRewritesAnthropicBeta(t *testing.T) { + info := &RelayInfo{ + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "operations": []interface{}{ + map[string]interface{}{ + "mode": "set_header", + "path": "anthropic-beta", + "value": map[string]interface{}{ + "advanced-tool-use-2025-11-20": nil, + "computer-use-2025-01-24": "computer-use-2025-01-24", + }, + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "anthropic-beta": "advanced-tool-use-2025-11-20, computer-use-2025-01-24", + }, + }, + } + + _, err := ApplyParamOverrideWithRelayInfo([]byte(`{"temperature":0.7}`), info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["anthropic-beta"] != "computer-use-2025-01-24" { + t.Fatalf("expected anthropic-beta to be rewritten, got: %v", info.RuntimeHeadersOverride["anthropic-beta"]) + } +} + +func TestGetEffectiveHeaderOverrideUsesRuntimeOverrideAsFinalResult(t *testing.T) { + info := &RelayInfo{ + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]interface{}{ + "x-runtime": "runtime-only", + }, + ChannelMeta: &ChannelMeta{ + HeadersOverride: map[string]interface{}{ + "X-Static": "static-value", + "X-Deleted": "should-not-exist", + }, + }, + } + + effective := GetEffectiveHeaderOverride(info) + if effective["x-runtime"] != "runtime-only" { + t.Fatalf("expected x-runtime from runtime override, got: %v", 
effective["x-runtime"]) + } + if _, exists := effective["x-static"]; exists { + t.Fatalf("expected runtime override to be final and not merge channel headers") + } +} + func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) { input := `{ "service_tier":"flex", diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 6d286d61..8b0789c0 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -101,6 +101,7 @@ type RelayInfo struct { RelayMode int OriginModelName string RequestURLPath string + RequestHeaders map[string]string ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -144,6 +145,10 @@ type RelayInfo struct { SubscriptionAmountUsedAfterPreConsume int64 IsClaudeBetaQuery bool // /v1/messages?beta=true IsChannelTest bool // channel test request + RetryIndex int + LastError *types.NewAPIError + RuntimeHeadersOverride map[string]interface{} + UseRuntimeHeadersOverride bool PriceData types.PriceData @@ -461,6 +466,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RequestURLPath: c.Request.URL.String(), + RequestHeaders: cloneRequestHeaders(c), IsStream: isStream, StartTime: startTime, @@ -493,6 +499,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { return info } +func cloneRequestHeaders(c *gin.Context) map[string]string { + if c == nil || c.Request == nil { + return nil + } + if len(c.Request.Header) == 0 { + return nil + } + headers := make(map[string]string, len(c.Request.Header)) + for key := range c.Request.Header { + value := strings.TrimSpace(c.Request.Header.Get(key)) + if value == "" { + continue + } + headers[key] = value + } + if len(headers) == 0 { + return nil + } + return headers +} + func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { var info 
*RelayInfo var err error diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index a133bab8..9a25237c 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -172,9 +172,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 1a41756b..d8ca4223 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "fmt" "net/http" @@ -46,15 +45,15 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } relaycommon.AppendRequestConversionFromRequest(info, convertedRequest) - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index a1b8e592..39bd44e6 100644 --- a/relay/gemini_handler.go +++ 
b/relay/gemini_handler.go @@ -157,9 +157,9 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } @@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI // apply param override if len(info.ParamOverride) > 0 { - reqMap := make(map[string]interface{}) - _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range info.ParamOverride { - reqMap[key] = value - } - jsonData, err = common.Marshal(reqMap) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData)) diff --git a/relay/image_handler.go b/relay/image_handler.go index e8329426..fc8ef500 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -70,9 +70,9 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/param_override_error.go b/relay/param_override_error.go new file mode 100644 
index 00000000..c2338298 --- /dev/null +++ b/relay/param_override_error.go @@ -0,0 +1,13 @@ +package relay + +import ( + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" +) + +func newAPIErrorFromParamOverride(err error) *types.NewAPIError { + if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok { + return relaycommon.NewAPIErrorFromParamOverride(fixedErr) + } + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) +} diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 8fe2930e..40d686f7 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -61,9 +61,9 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index b3169e72..18f1b711 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -96,9 +96,9 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * // apply param override if len(info.ParamOverride) > 0 { - jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info)) + jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + return newAPIErrorFromParamOverride(err) } } diff --git a/service/channel_affinity.go b/service/channel_affinity.go index 
524c6574..3e90b9c2 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -45,6 +45,7 @@ type channelAffinityMeta struct { TTLSeconds int RuleName string SkipRetry bool + ParamTemplate map[string]interface{} KeySourceType string KeySourceKey string KeySourcePath string @@ -415,6 +416,84 @@ func buildChannelAffinityKeyHint(s string) string { return s[:4] + "..." + s[len(s)-4:] } +func cloneStringAnyMap(src map[string]interface{}) map[string]interface{} { + if len(src) == 0 { + return map[string]interface{}{} + } + dst := make(map[string]interface{}, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func mergeChannelOverride(base map[string]interface{}, tpl map[string]interface{}) map[string]interface{} { + if len(base) == 0 && len(tpl) == 0 { + return map[string]interface{}{} + } + if len(tpl) == 0 { + return base + } + out := cloneStringAnyMap(base) + for k, v := range tpl { + out[k] = v + } + return out +} + +func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinityMeta) { + if c == nil { + return + } + if len(meta.ParamTemplate) == 0 { + return + } + + templateInfo := map[string]interface{}{ + "applied": true, + "rule_name": meta.RuleName, + "param_override_keys": len(meta.ParamTemplate), + } + if anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo); ok { + if info, ok := anyInfo.(map[string]interface{}); ok { + info["override_template"] = templateInfo + c.Set(ginKeyChannelAffinityLogInfo, info) + return + } + } + c.Set(ginKeyChannelAffinityLogInfo, map[string]interface{}{ + "reason": meta.RuleName, + "rule_name": meta.RuleName, + "using_group": meta.UsingGroup, + "model": meta.ModelName, + "request_path": meta.RequestPath, + "key_source": meta.KeySourceType, + "key_key": meta.KeySourceKey, + "key_path": meta.KeySourcePath, + "key_hint": meta.KeyHint, + "key_fp": meta.KeyFingerprint, + "override_template": templateInfo, + }) +} + +// ApplyChannelAffinityOverrideTemplate merges per-rule channel 
override templates onto the selected channel override config. +func ApplyChannelAffinityOverrideTemplate(c *gin.Context, paramOverride map[string]interface{}) (map[string]interface{}, bool) { + if c == nil { + return paramOverride, false + } + meta, ok := getChannelAffinityMeta(c) + if !ok { + return paramOverride, false + } + if len(meta.ParamTemplate) == 0 { + return paramOverride, false + } + + mergedParam := mergeChannelOverride(paramOverride, meta.ParamTemplate) + appendChannelAffinityTemplateAdminInfo(c, meta) + return mergedParam, true +} + func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) { setting := operation_setting.GetChannelAffinitySetting() if setting == nil || !setting.Enabled { @@ -466,6 +545,7 @@ func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup TTLSeconds: ttlSeconds, RuleName: rule.Name, SkipRetry: rule.SkipRetryOnFailure, + ParamTemplate: cloneStringAnyMap(rule.ParamOverrideTemplate), KeySourceType: strings.TrimSpace(usedSource.Type), KeySourceKey: strings.TrimSpace(usedSource.Key), KeySourcePath: strings.TrimSpace(usedSource.Path), diff --git a/service/channel_affinity_template_test.go b/service/channel_affinity_template_test.go new file mode 100644 index 00000000..acf30154 --- /dev/null +++ b/service/channel_affinity_template_test.go @@ -0,0 +1,145 @@ +package service + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func buildChannelAffinityTemplateContextForTest(meta channelAffinityMeta) *gin.Context { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + setChannelAffinityContext(ctx, meta) + return ctx +} + +func TestApplyChannelAffinityOverrideTemplate_NoTemplate(t *testing.T) { + ctx := 
buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + RuleName: "rule-no-template", + }) + base := map[string]interface{}{ + "temperature": 0.7, + } + + merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) + require.False(t, applied) + require.Equal(t, base, merged) +} + +func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) { + ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + RuleName: "rule-with-template", + ParamTemplate: map[string]interface{}{ + "temperature": 0.2, + "top_p": 0.95, + }, + UsingGroup: "default", + ModelName: "gpt-4.1", + RequestPath: "/v1/responses", + KeySourceType: "gjson", + KeySourcePath: "prompt_cache_key", + KeyHint: "abcd...wxyz", + KeyFingerprint: "abcd1234", + }) + base := map[string]interface{}{ + "temperature": 0.7, + "max_tokens": 2000, + } + + merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) + require.True(t, applied) + require.Equal(t, 0.2, merged["temperature"]) + require.Equal(t, 0.95, merged["top_p"]) + require.Equal(t, 2000, merged["max_tokens"]) + require.Equal(t, 0.7, base["temperature"]) + + anyInfo, ok := ctx.Get(ginKeyChannelAffinityLogInfo) + require.True(t, ok) + info, ok := anyInfo.(map[string]interface{}) + require.True(t, ok) + overrideInfoAny, ok := info["override_template"] + require.True(t, ok) + overrideInfo, ok := overrideInfoAny.(map[string]interface{}) + require.True(t, ok) + require.Equal(t, true, overrideInfo["applied"]) + require.Equal(t, "rule-with-template", overrideInfo["rule_name"]) + require.EqualValues(t, 2, overrideInfo["param_override_keys"]) +} + +func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) { + gin.SetMode(gin.TestMode) + + setting := operation_setting.GetChannelAffinitySetting() + require.NotNil(t, setting) + + var codexRule *operation_setting.ChannelAffinityRule + for i := range setting.Rules { + rule := &setting.Rules[i] + if strings.EqualFold(strings.TrimSpace(rule.Name), 
"codex cli trace") { + codexRule = rule + break + } + } + require.NotNil(t, codexRule) + + affinityValue := fmt.Sprintf("pc-hit-%d", time.Now().UnixNano()) + cacheKeySuffix := buildChannelAffinityCacheKeySuffix(*codexRule, "default", affinityValue) + + cache := getChannelAffinityCache() + require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9527, time.Minute)) + t.Cleanup(func() { + _, _ = cache.DeleteMany([]string{cacheKeySuffix}) + }) + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(fmt.Sprintf(`{"prompt_cache_key":"%s"}`, affinityValue))) + ctx.Request.Header.Set("Content-Type", "application/json") + + channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default") + require.True(t, found) + require.Equal(t, 9527, channelID) + + baseOverride := map[string]interface{}{ + "temperature": 0.2, + } + mergedOverride, applied := ApplyChannelAffinityOverrideTemplate(ctx, baseOverride) + require.True(t, applied) + require.Equal(t, 0.2, mergedOverride["temperature"]) + + info := &relaycommon.RelayInfo{ + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + "Session_id": "sess-123", + "User-Agent": "codex-cli-test", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ParamOverride: mergedOverride, + HeadersOverride: map[string]interface{}{ + "X-Static": "legacy-static", + }, + }, + } + + _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5"}`), info) + require.NoError(t, err) + require.True(t, info.UseRuntimeHeadersOverride) + + require.Equal(t, "legacy-static", info.RuntimeHeadersOverride["x-static"]) + require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) + require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) + require.Equal(t, "codex-cli-test", info.RuntimeHeadersOverride["user-agent"]) + + _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] + require.False(t, 
exists) + _, exists = info.RuntimeHeadersOverride["x-codex-turn-metadata"] + require.False(t, exists) +} diff --git a/setting/operation_setting/channel_affinity_setting.go b/setting/operation_setting/channel_affinity_setting.go index 22f19824..7727315a 100644 --- a/setting/operation_setting/channel_affinity_setting.go +++ b/setting/operation_setting/channel_affinity_setting.go @@ -18,6 +18,8 @@ type ChannelAffinityRule struct { ValueRegex string `json:"value_regex"` TTLSeconds int `json:"ttl_seconds"` + ParamOverrideTemplate map[string]interface{} `json:"param_override_template,omitempty"` + SkipRetryOnFailure bool `json:"skip_retry_on_failure,omitempty"` IncludeUsingGroup bool `json:"include_using_group"` @@ -32,6 +34,44 @@ type ChannelAffinitySetting struct { Rules []ChannelAffinityRule `json:"rules"` } +var codexCliPassThroughHeaders = []string{ + "Originator", + "Session_id", + "User-Agent", + "X-Codex-Beta-Features", + "X-Codex-Turn-Metadata", +} + +var claudeCliPassThroughHeaders = []string{ + "X-Stainless-Arch", + "X-Stainless-Lang", + "X-Stainless-Os", + "X-Stainless-Package-Version", + "X-Stainless-Retry-Count", + "X-Stainless-Runtime", + "X-Stainless-Runtime-Version", + "X-Stainless-Timeout", + "User-Agent", + "X-App", + "Anthropic-Beta", + "Anthropic-Dangerous-Direct-Browser-Access", + "Anthropic-Version", +} + +func buildPassHeaderTemplate(headers []string) map[string]interface{} { + clonedHeaders := make([]string, 0, len(headers)) + clonedHeaders = append(clonedHeaders, headers...) 
+ return map[string]interface{}{ + "operations": []map[string]interface{}{ + { + "mode": "pass_headers", + "value": clonedHeaders, + "keep_origin": true, + }, + }, + } +} + var channelAffinitySetting = ChannelAffinitySetting{ Enabled: true, SwitchOnSuccess: true, @@ -39,32 +79,34 @@ var channelAffinitySetting = ChannelAffinitySetting{ DefaultTTLSeconds: 3600, Rules: []ChannelAffinityRule{ { - Name: "codex trace", + Name: "codex cli trace", ModelRegex: []string{"^gpt-.*$"}, PathRegex: []string{"/v1/responses"}, KeySources: []ChannelAffinityKeySource{ {Type: "gjson", Path: "prompt_cache_key"}, }, - ValueRegex: "", - TTLSeconds: 0, - SkipRetryOnFailure: false, - IncludeUsingGroup: true, - IncludeRuleName: true, - UserAgentInclude: nil, + ValueRegex: "", + TTLSeconds: 0, + ParamOverrideTemplate: buildPassHeaderTemplate(codexCliPassThroughHeaders), + SkipRetryOnFailure: false, + IncludeUsingGroup: true, + IncludeRuleName: true, + UserAgentInclude: nil, }, { - Name: "claude code trace", + Name: "claude cli trace", ModelRegex: []string{"^claude-.*$"}, PathRegex: []string{"/v1/messages"}, KeySources: []ChannelAffinityKeySource{ {Type: "gjson", Path: "metadata.user_id"}, }, - ValueRegex: "", - TTLSeconds: 0, - SkipRetryOnFailure: false, - IncludeUsingGroup: true, - IncludeRuleName: true, - UserAgentInclude: nil, + ValueRegex: "", + TTLSeconds: 0, + ParamOverrideTemplate: buildPassHeaderTemplate(claudeCliPassThroughHeaders), + SkipRetryOnFailure: false, + IncludeUsingGroup: true, + IncludeRuleName: true, + UserAgentInclude: nil, }, }, } diff --git a/web/package.json b/web/package.json index 97c7c821..4d8c7e7f 100644 --- a/web/package.json +++ b/web/package.json @@ -10,7 +10,7 @@ "@visactor/react-vchart": "~1.8.8", "@visactor/vchart": "~1.8.8", "@visactor/vchart-semi-theme": "~1.8.8", - "axios": "1.13.5", + "axios": "1.12.0", "clsx": "^2.1.1", "dayjs": "^1.11.11", "history": "^5.3.0", diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx 
b/web/src/components/table/channels/modals/EditChannelModal.jsx index 3d3afcc3..3a91207d 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -59,6 +59,7 @@ import ModelSelectModal from './ModelSelectModal'; import SingleModelSelectModal from './SingleModelSelectModal'; import OllamaModelModal from './OllamaModelModal'; import CodexOAuthModal from './CodexOAuthModal'; +import ParamOverrideEditorModal from './ParamOverrideEditorModal'; import JSONEditor from '../../../common/ui/JSONEditor'; import SecureVerificationModal from '../../../common/modals/SecureVerificationModal'; import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal'; @@ -75,6 +76,7 @@ import { IconServer, IconSetting, IconCode, + IconCopy, IconGlobe, IconBolt, IconSearch, @@ -99,6 +101,28 @@ const REGION_EXAMPLE = { 'claude-3-5-sonnet-20240620': 'europe-west1', }; +const PARAM_OVERRIDE_LEGACY_TEMPLATE = { + temperature: 0, +}; + +const PARAM_OVERRIDE_OPERATIONS_TEMPLATE = { + operations: [ + { + path: 'temperature', + mode: 'set', + value: 0.7, + conditions: [ + { + path: 'model', + mode: 'prefix', + value: 'openai/', + }, + ], + logic: 'AND', + }, + ], +}; + // 支持并且已适配通过接口获取模型列表的渠道类型 const MODEL_FETCHABLE_TYPES = new Set([ 1, 4, 14, 34, 17, 26, 27, 24, 47, 25, 20, 23, 31, 40, 42, 48, 43, @@ -148,6 +172,7 @@ const EditChannelModal = (props) => { base_url: '', other: '', model_mapping: '', + param_override: '', status_code_mapping: '', models: [], auto_ban: 1, @@ -251,11 +276,69 @@ const EditChannelModal = (props) => { name: keyword, }); }, [modelSearchMatchedCount, modelSearchValue, t]); + const paramOverrideMeta = useMemo(() => { + const raw = + typeof inputs.param_override === 'string' + ? 
inputs.param_override.trim() + : ''; + if (!raw) { + return { + tagLabel: t('不更改'), + tagColor: 'grey', + preview: t( + '此项可选,用于覆盖请求参数。不支持覆盖 stream 参数', + ), + }; + } + if (!verifyJSON(raw)) { + return { + tagLabel: t('JSON格式错误'), + tagColor: 'red', + preview: raw, + }; + } + try { + const parsed = JSON.parse(raw); + const pretty = JSON.stringify(parsed, null, 2); + if ( + parsed && + typeof parsed === 'object' && + !Array.isArray(parsed) && + Array.isArray(parsed.operations) + ) { + return { + tagLabel: `${t('新格式模板')} (${parsed.operations.length})`, + tagColor: 'cyan', + preview: pretty, + }; + } + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return { + tagLabel: `${t('旧格式模板')} (${Object.keys(parsed).length})`, + tagColor: 'blue', + preview: pretty, + }; + } + return { + tagLabel: t('自定义 JSON'), + tagColor: 'orange', + preview: pretty, + }; + } catch (error) { + return { + tagLabel: t('JSON格式错误'), + tagColor: 'red', + preview: raw, + }; + } + }, [inputs.param_override, t]); const [isIonetChannel, setIsIonetChannel] = useState(false); const [ionetMetadata, setIonetMetadata] = useState(null); const [codexOAuthModalVisible, setCodexOAuthModalVisible] = useState(false); const [codexCredentialRefreshing, setCodexCredentialRefreshing] = useState(false); + const [paramOverrideEditorVisible, setParamOverrideEditorVisible] = + useState(false); // 密钥显示状态 const [keyDisplayState, setKeyDisplayState] = useState({ @@ -582,6 +665,100 @@ const EditChannelModal = (props) => { } }; + const copyParamOverrideJson = async () => { + const raw = + typeof inputs.param_override === 'string' + ? 
inputs.param_override.trim() + : ''; + if (!raw) { + showInfo(t('暂无可复制 JSON')); + return; + } + + let content = raw; + if (verifyJSON(raw)) { + try { + content = JSON.stringify(JSON.parse(raw), null, 2); + } catch (error) { + content = raw; + } + } + + const ok = await copy(content); + if (ok) { + showSuccess(t('参数覆盖 JSON 已复制')); + } else { + showError(t('复制失败')); + } + }; + + const parseParamOverrideInput = () => { + const raw = + typeof inputs.param_override === 'string' + ? inputs.param_override.trim() + : ''; + if (!raw) return null; + if (!verifyJSON(raw)) { + throw new Error(t('当前参数覆盖不是合法的 JSON')); + } + return JSON.parse(raw); + }; + + const applyParamOverrideTemplate = ( + templateType = 'operations', + applyMode = 'fill', + ) => { + try { + const parsedCurrent = parseParamOverrideInput(); + if (templateType === 'legacy') { + if (applyMode === 'fill') { + handleInputChange( + 'param_override', + JSON.stringify(PARAM_OVERRIDE_LEGACY_TEMPLATE, null, 2), + ); + return; + } + const currentLegacy = + parsedCurrent && + typeof parsedCurrent === 'object' && + !Array.isArray(parsedCurrent) && + !Array.isArray(parsedCurrent.operations) + ? parsedCurrent + : {}; + const merged = { + ...PARAM_OVERRIDE_LEGACY_TEMPLATE, + ...currentLegacy, + }; + handleInputChange('param_override', JSON.stringify(merged, null, 2)); + return; + } + + if (applyMode === 'fill') { + handleInputChange( + 'param_override', + JSON.stringify(PARAM_OVERRIDE_OPERATIONS_TEMPLATE, null, 2), + ); + return; + } + const currentOperations = + parsedCurrent && + typeof parsedCurrent === 'object' && + !Array.isArray(parsedCurrent) && + Array.isArray(parsedCurrent.operations) + ? 
parsedCurrent.operations + : []; + const merged = { + operations: [ + ...currentOperations, + ...PARAM_OVERRIDE_OPERATIONS_TEMPLATE.operations, + ], + }; + handleInputChange('param_override', JSON.stringify(merged, null, 2)); + } catch (error) { + showError(error.message || t('模板应用失败')); + } + }; + const loadChannel = async () => { setLoading(true); let res = await API.get(`/api/channel/${channelId}`); @@ -1242,6 +1419,7 @@ const EditChannelModal = (props) => { const submit = async () => { const formValues = formApiRef.current ? formApiRef.current.getValues() : {}; let localInputs = { ...formValues }; + localInputs.param_override = inputs.param_override; if (localInputs.type === 57) { if (batch) { @@ -3150,78 +3328,73 @@ const EditChannelModal = (props) => { initValue={autoBan} /> - - handleInputChange('param_override', value) - } - extraText={ -
- +
+ {t('参数覆盖')} + + + + +
+ + {t('此项可选,用于覆盖请求参数。不支持覆盖 stream 参数')} + +
+
+ + {paramOverrideMeta.tagLabel} + + + + +
- } - showClear - /> +
+                          {paramOverrideMeta.preview}
+                        
+
+
{ /> + setParamOverrideEditorVisible(false)} + onSave={(nextValue) => { + handleInputChange('param_override', nextValue); + setParamOverrideEditorVisible(false); + }} + /> + . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Button, + Card, + Col, + Collapse, + Input, + Modal, + Row, + Select, + Space, + Switch, + Tag, + TextArea, + Typography, +} from '@douyinfe/semi-ui'; +import { IconDelete, IconPlus } from '@douyinfe/semi-icons'; +import { copy, showError, showSuccess, verifyJSON } from '../../../../helpers'; +import { + CLAUDE_CLI_HEADER_PASSTHROUGH_TEMPLATE, + CODEX_CLI_HEADER_PASSTHROUGH_TEMPLATE, +} from '../../../../constants/channel-affinity-template.constants'; + +const { Text } = Typography; + +const OPERATION_MODE_OPTIONS = [ + { label: '设置字段', value: 'set' }, + { label: '删除字段', value: 'delete' }, + { label: '追加到末尾', value: 'append' }, + { label: '追加到开头', value: 'prepend' }, + { label: '复制字段', value: 'copy' }, + { label: '移动字段', value: 'move' }, + { label: '字符串替换', value: 'replace' }, + { label: '正则替换', value: 'regex_replace' }, + { label: '裁剪前缀', value: 'trim_prefix' }, + { label: '裁剪后缀', value: 'trim_suffix' }, + { label: '确保前缀', value: 'ensure_prefix' }, + { label: '确保后缀', value: 'ensure_suffix' }, + { label: '去掉空白', value: 'trim_space' }, + { label: '转小写', value: 'to_lower' }, + { label: '转大写', value: 'to_upper' }, + { label: '返回自定义错误', value: 'return_error' }, + { label: '清理对象项', value: 'prune_objects' }, + { label: '请求头透传', value: 'pass_headers' }, + { label: '字段同步', value: 'sync_fields' }, + { label: '设置请求头', value: 'set_header' }, + { label: '删除请求头', value: 'delete_header' }, + { label: '复制请求头', value: 'copy_header' }, + { label: '移动请求头', value: 'move_header' }, +]; + +const OPERATION_MODE_VALUES = new Set( + OPERATION_MODE_OPTIONS.map((item) => item.value), +); + +const 
CONDITION_MODE_OPTIONS = [
  { label: '完全匹配', value: 'full' },
  { label: '前缀匹配', value: 'prefix' },
  { label: '后缀匹配', value: 'suffix' },
  { label: '包含', value: 'contains' },
  { label: '大于', value: 'gt' },
  { label: '大于等于', value: 'gte' },
  { label: '小于', value: 'lt' },
  { label: '小于等于', value: 'lte' },
];

// Fast membership check used to validate/normalize a condition's `mode` string.
const CONDITION_MODE_VALUES = new Set(
  CONDITION_MODE_OPTIONS.map((item) => item.value),
);

// Declares which editor inputs each operation mode uses:
//   path        — the mode requires a target path (or header name),
//   pathOptional— path is accepted but not required,
//   pathAlias   — `path` may stand in for a missing `from`/`to`,
//   value       — the mode carries a value input,
//   keepOrigin  — the mode exposes the "keep origin" toggle,
//   from / to   — the mode uses source/destination fields.
// `to: false` (replace / regex_replace) marks `to` as present but optional.
const MODE_META = {
  delete: { path: true },
  set: { path: true, value: true, keepOrigin: true },
  append: { path: true, value: true, keepOrigin: true },
  prepend: { path: true, value: true, keepOrigin: true },
  copy: { from: true, to: true },
  move: { from: true, to: true },
  replace: { path: true, from: true, to: false },
  regex_replace: { path: true, from: true, to: false },
  trim_prefix: { path: true, value: true },
  trim_suffix: { path: true, value: true },
  ensure_prefix: { path: true, value: true },
  ensure_suffix: { path: true, value: true },
  trim_space: { path: true },
  to_lower: { path: true },
  to_upper: { path: true },
  return_error: { value: true },
  prune_objects: { pathOptional: true, value: true },
  pass_headers: { value: true, keepOrigin: true },
  sync_fields: { from: true, to: true },
  set_header: { path: true, value: true, keepOrigin: true },
  delete_header: { path: true },
  copy_header: { from: true, to: true, keepOrigin: true, pathAlias: true },
  move_header: { from: true, to: true, keepOrigin: true, pathAlias: true },
};

// Modes for which a non-empty value input is mandatory during validation.
const VALUE_REQUIRED_MODES = new Set([
  'trim_prefix',
  'trim_suffix',
  'ensure_prefix',
  'ensure_suffix',
  'set_header',
  'return_error',
  'prune_objects',
  'pass_headers',
]);

// Modes that require a source (`from`) field; pathAlias modes may satisfy
// this via `path` instead (see validateOperations / buildOperationsJson).
const FROM_REQUIRED_MODES = new Set([
  'copy',
  'move',
  'replace',
  'regex_replace',
  'copy_header',
  'move_header',
  'sync_fields',
]);

// Modes that require a destination (`to`) field; same pathAlias caveat.
const TO_REQUIRED_MODES = new Set([
  'copy',
  'move',
  'copy_header',
  'move_header',
  'sync_fields',
]);

// Human-readable (zh-CN) one-line description per operation mode, shown in
// the editor UI. Keys mirror MODE_META.
const MODE_DESCRIPTIONS = {
  set: '把值写入目标字段',
  delete: '删除目标字段',
  append: '把值追加到数组 / 字符串 / 对象末尾',
  prepend: '把值追加到数组 / 字符串 / 对象开头',
  copy: '把来源字段复制到目标字段',
  move: '把来源字段移动到目标字段',
  replace: '在目标字段里做字符串替换',
  regex_replace: '在目标字段里做正则替换',
  trim_prefix: '去掉字符串前缀',
  trim_suffix: '去掉字符串后缀',
  ensure_prefix: '确保字符串有指定前缀',
  ensure_suffix: '确保字符串有指定后缀',
  trim_space: '去掉字符串头尾空白',
  to_lower: '把字符串转成小写',
  to_upper: '把字符串转成大写',
  return_error: '立即返回自定义错误',
  prune_objects: '按条件清理对象中的子项',
  pass_headers: '把指定请求头透传到上游请求',
  sync_fields: '在一个字段有值、另一个缺失时自动补齐',
  set_header: '设置运行期请求头(支持整值覆盖,或用 JSON 映射按逗号 token 替换/删除)',
  delete_header: '删除运行期请求头',
  copy_header: '复制请求头',
  move_header: '移动请求头',
};

// Label for the path input, specialized per mode (header name vs JSON path).
const getModePathLabel = (mode) => {
  if (mode === 'set_header' || mode === 'delete_header') {
    return '请求头名称';
  }
  if (mode === 'prune_objects') {
    return '目标路径(可选)';
  }
  return '目标字段路径';
};

// Example placeholder for the path input, specialized per mode.
const getModePathPlaceholder = (mode) => {
  if (mode === 'set_header') return 'Authorization';
  if (mode === 'delete_header') return 'X-Debug-Mode';
  if (mode === 'prune_objects') return 'messages';
  return 'temperature';
};

// Label for the `from` input (match text / regex / header / field).
const getModeFromLabel = (mode) => {
  if (mode === 'replace') return '匹配文本';
  if (mode === 'regex_replace') return '正则表达式';
  if (mode === 'copy_header' || mode === 'move_header') return '来源请求头';
  return '来源字段';
};

// Example placeholder for the `from` input.
const getModeFromPlaceholder = (mode) => {
  if (mode === 'replace') return 'openai/';
  if (mode === 'regex_replace') return '^gpt-';
  if (mode === 'copy_header' || mode === 'move_header') return 'Authorization';
  return 'model';
};

// Label for the `to` input (replacement text / target header / target field).
const getModeToLabel = (mode) => {
  if (mode === 'replace' || mode === 'regex_replace') return '替换为';
  if (mode === 'copy_header' || mode === 'move_header') return '目标请求头';
  return '目标字段';
};

// Example placeholder for the `to` input.
const getModeToPlaceholder = (mode) => {
  if (mode === 'replace') return '(可留空)';
  if (mode === 'regex_replace') return 'openai/gpt-';
  if (mode === 'copy_header' || mode === 'move_header') return 'X-Upstream-Auth';
  return 'original_model';
};

// Label for the value input, specialized per mode.
const getModeValueLabel = (mode) => {
  if (mode === 'set_header') return '请求头值(支持字符串或 JSON 映射)';
  if (mode === 'pass_headers') return '透传请求头(支持逗号分隔或 JSON 数组)';
  if (
    mode === 'trim_prefix' ||
    mode === 'trim_suffix' ||
    mode === 'ensure_prefix' ||
    mode === 'ensure_suffix'
  ) {
    return '前后缀文本';
  }
  if (mode === 'prune_objects') {
    return '清理规则(字符串或 JSON 对象)';
  }
  return '值(支持 JSON 或普通文本)';
};

// Multi-line usage placeholder for the value input; set_header gets a worked
// example of both the plain-string and JSON-map forms.
const getModeValuePlaceholder = (mode) => {
  if (mode === 'set_header') {
    return [
      'String example:',
      'Bearer sk-xxx',
      '',
      'JSON map example:',
      '{"advanced-tool-use-2025-11-20": null, "computer-use-2025-01-24": "computer-use-2025-01-24"}',
      '',
      'JSON map wildcard:',
      '{"*": null, "computer-use-2025-11-24": "computer-use-2025-11-24"}',
    ].join('\n');
  }
  if (mode === 'pass_headers') return 'Authorization, X-Request-Id';
  if (
    mode === 'trim_prefix' ||
    mode === 'trim_suffix' ||
    mode === 'ensure_prefix' ||
    mode === 'ensure_suffix'
  ) {
    return 'openai/';
  }
  if (mode === 'prune_objects') {
    return '{"type":"redacted_thinking"}';
  }
  return '0.7';
};

// Extra help text under the value input; only set_header needs it.
const getModeValueHelp = (mode) => {
  if (mode !== 'set_header') return '';
  return '字符串:整条请求头直接覆盖。JSON 映射:按逗号分隔 token 逐项处理,null 表示删除,string/array 表示替换,* 表示兜底规则。';
};

// Options for the sync_fields target-type selector (JSON body vs header).
const SYNC_TARGET_TYPE_OPTIONS = [
  { label: '请求体字段', value: 'json' },
  { label: '请求头字段', value: 'header' },
];

// Starter payload for the legacy (flat JSON object) override format.
const LEGACY_TEMPLATE = {
  temperature: 0,
  max_tokens: 1000,
};

// Starter payload for the new rule-set format: one conditional `set`.
const OPERATION_TEMPLATE = {
  operations: [
    {
      path: 'temperature',
      mode: 'set',
      value: 0.7,
      conditions: [
        {
          path: 'model',
          mode: 'prefix',
          value: 'openai/',
        },
      ],
      logic: 'AND',
    },
  ],
};

// Scenario template: forward the Authorization header upstream.
const HEADER_PASSTHROUGH_TEMPLATE = {
  operations: [
    {
      mode: 'pass_headers',
      value: ['Authorization'],
      keep_origin: true,
    },
  ],
};

// Scenario template: force 4K image output for the Gemini image model.
const GEMINI_IMAGE_4K_TEMPLATE = {
  operations: [
    {
      mode: 'set',
      path: 'generationConfig.imageConfig.imageSize',
      value: '4K',
+ conditions: [ + { + path: 'original_model', + mode: 'contains', + value: 'gemini-3-pro-image-preview', + }, + ], + logic: 'AND', + }, + ], +}; + +const AWS_BEDROCK_ANTHROPIC_BETA_OVERRIDE_TEMPLATE = { + operations: [ + { + mode: 'set_header', + path: 'anthropic-beta', + value: { + 'advanced-tool-use-2025-11-20': 'tool-search-tool-2025-10-19', + bash_20241022: null, + bash_20250124: null, + 'code-execution-2025-08-25': null, + 'compact-2026-01-12': 'compact-2026-01-12', + 'computer-use-2025-01-24': 'computer-use-2025-01-24', + 'computer-use-2025-11-24': 'computer-use-2025-11-24', + 'context-1m-2025-08-07': 'context-1m-2025-08-07', + 'context-management-2025-06-27': 'context-management-2025-06-27', + 'effort-2025-11-24': null, + 'fast-mode-2026-02-01': null, + 'files-api-2025-04-14': null, + 'fine-grained-tool-streaming-2025-05-14': null, + 'interleaved-thinking-2025-05-14': 'interleaved-thinking-2025-05-14', + 'mcp-client-2025-11-20': null, + 'mcp-client-2025-04-04': null, + 'mcp-servers-2025-12-04': null, + 'output-128k-2025-02-19': null, + 'structured-output-2024-03-01': null, + 'prompt-caching-scope-2026-01-05': null, + 'skills-2025-10-02': null, + 'structured-outputs-2025-11-13': null, + text_editor_20241022: null, + text_editor_20250124: null, + 'token-efficient-tools-2025-02-19': null, + 'tool-search-tool-2025-10-19': 'tool-search-tool-2025-10-19', + 'web-fetch-2025-09-10': null, + 'web-search-2025-03-05': null, + }, + }, + ], +}; + +const TEMPLATE_GROUP_OPTIONS = [ + { label: '基础模板', value: 'basic' }, + { label: '场景模板', value: 'scenario' }, +]; + +const TEMPLATE_PRESET_CONFIG = { + operations_default: { + group: 'basic', + label: '新格式模板(规则集)', + kind: 'operations', + payload: OPERATION_TEMPLATE, + }, + legacy_default: { + group: 'basic', + label: '旧格式模板(JSON 对象)', + kind: 'legacy', + payload: LEGACY_TEMPLATE, + }, + pass_headers_auth: { + group: 'scenario', + label: '请求头透传(Authorization)', + kind: 'operations', + payload: HEADER_PASSTHROUGH_TEMPLATE, + }, + 
  gemini_image_4k: {
    group: 'scenario',
    label: 'Gemini 图片 4K',
    kind: 'operations',
    payload: GEMINI_IMAGE_4K_TEMPLATE,
  },
  claude_cli_headers_passthrough: {
    group: 'scenario',
    label: 'Claude CLI 请求头透传',
    kind: 'operations',
    payload: CLAUDE_CLI_HEADER_PASSTHROUGH_TEMPLATE,
  },
  codex_cli_headers_passthrough: {
    group: 'scenario',
    label: 'Codex CLI 请求头透传',
    kind: 'operations',
    payload: CODEX_CLI_HEADER_PASSTHROUGH_TEMPLATE,
  },
  aws_bedrock_anthropic_beta_override: {
    group: 'scenario',
    label: 'AWS Bedrock anthropic-beta覆盖',
    kind: 'operations',
    payload: AWS_BEDROCK_ANTHROPIC_BETA_OVERRIDE_TEMPLATE,
  },
};

// Where the field-guide panel inserts a picked field (path / from / to).
const FIELD_GUIDE_TARGET_OPTIONS = [
  { label: '填入目标路径', value: 'path' },
  { label: '填入来源字段', value: 'from' },
  { label: '填入目标字段', value: 'to' },
];

// Curated field catalog shown in the field-guide drawer; each entry has the
// JSON path key, a display label, and a usage tip.
const BUILTIN_FIELD_SECTIONS = [
  {
    title: '常用请求字段',
    fields: [
      {
        key: 'model',
        label: '模型名称',
        tip: '支持多级模型名,例如 openai/gpt-4o-mini',
      },
      { key: 'temperature', label: '采样温度', tip: '控制输出随机性' },
      { key: 'max_tokens', label: '最大输出 Token', tip: '控制输出长度上限' },
      { key: 'messages.-1.content', label: '最后一条消息内容', tip: '常用于重写用户输入' },
    ],
  },
  {
    title: '上下文字段',
    fields: [
      { key: 'retry.is_retry', label: '是否重试', tip: 'true 表示重试请求' },
      { key: 'last_error.code', label: '上次错误码', tip: '配合重试策略使用' },
      {
        key: 'metadata.conversation_id',
        label: '会话 ID',
        tip: '可用于路由或缓存命中',
      },
    ],
  },
  {
    title: '请求头映射字段',
    fields: [
      {
        key: 'header_override_normalized.authorization',
        label: '标准化 Authorization',
        tip: '统一小写后可稳定匹配',
      },
      {
        key: 'header_override_normalized.x_debug_mode',
        label: '标准化 X-Debug-Mode',
        tip: '适合灰度 / 调试开关判断',
      },
    ],
  },
];

// mode value → display label, derived from OPERATION_MODE_OPTIONS
// (declared earlier in this file).
const OPERATION_MODE_LABEL_MAP = OPERATION_MODE_OPTIONS.reduce((acc, item) => {
  acc[item.value] = item.label;
  return acc;
}, {});

// Monotonic client-side id generator for list rows (operations, conditions,
// prune rules); ids are only unique within this browser session.
let localIdSeed = 0;
const nextLocalId = () => `param_override_${Date.now()}_${localIdSeed++}`;

// Serialize an arbitrary value into the text shown in a value input:
// strings pass through, everything else is JSON-encoded (best effort).
const toValueText = (value) => {
  if (value === undefined) return '';
  if (typeof value === 'string') return value;
  try {
    return JSON.stringify(value);
  } catch (error) {
    return String(value);
  }
};

// Inverse of toValueText: try JSON.parse, otherwise keep the raw string.
// Empty/whitespace input becomes ''.
const parseLooseValue = (valueText) => {
  const raw = String(valueText ?? '');
  if (raw.trim() === '') return '';
  try {
    return JSON.parse(raw);
  } catch (error) {
    return raw;
  }
};

// Normalize a pass_headers value into a clean list of header names.
// Accepts: array of names, { headers: [...] }, { header: '...' }, or a
// comma-separated string. Anything else yields [].
const parsePassHeaderNames = (rawValue) => {
  if (Array.isArray(rawValue)) {
    return rawValue
      .map((item) => String(item ?? '').trim())
      .filter(Boolean);
  }
  if (rawValue && typeof rawValue === 'object') {
    if (Array.isArray(rawValue.headers)) {
      return rawValue.headers
        .map((item) => String(item ?? '').trim())
        .filter(Boolean);
    }
    if (rawValue.header !== undefined) {
      const single = String(rawValue.header ?? '').trim();
      return single ? [single] : [];
    }
    return [];
  }
  if (typeof rawValue === 'string') {
    return rawValue
      .split(',')
      .map((item) => item.trim())
      .filter(Boolean);
  }
  return [];
};

// Parse a return_error value into the structured form used by the editor.
// A JSON object enables advanced mode (status/code/type/skip_retry); plain
// text becomes a simple-mode message. Invalid status codes fall back to 400.
const parseReturnErrorDraft = (valueText) => {
  const defaults = {
    message: '',
    statusCode: 400,
    code: '',
    type: '',
    skipRetry: true,
    simpleMode: true,
  };

  const raw = String(valueText ?? '').trim();
  if (!raw) {
    return defaults;
  }

  try {
    const parsed = JSON.parse(raw);
    if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
      // Accept both `status_code` and the shorter `status` alias.
      const statusRaw =
        parsed.status_code !== undefined ? parsed.status_code : parsed.status;
      const statusValue = Number(statusRaw);
      return {
        ...defaults,
        message: String(parsed.message || parsed.msg || '').trim(),
        statusCode:
          Number.isInteger(statusValue) &&
          statusValue >= 100 &&
          statusValue <= 599
            ? statusValue
            : 400,
        code: String(parsed.code || '').trim(),
        type: String(parsed.type || '').trim(),
        skipRetry: parsed.skip_retry !== false,
        simpleMode: false,
      };
    }
  } catch (error) {
    // treat as plain text message
  }

  return {
    ...defaults,
    message: raw,
    simpleMode: true,
  };
};

// Inverse of parseReturnErrorDraft: simple mode emits the bare message,
// advanced mode emits the JSON payload (skip_retry only when disabled).
const buildReturnErrorValueText = (draft = {}) => {
  const message = String(draft.message || '').trim();
  if (draft.simpleMode) {
    return message;
  }

  const statusCode = Number(draft.statusCode);
  const payload = {
    message,
    status_code:
      Number.isInteger(statusCode) && statusCode >= 100 && statusCode <= 599
        ? statusCode
        : 400,
  };
  const code = String(draft.code || '').trim();
  const type = String(draft.type || '').trim();
  if (code) payload.code = code;
  if (type) payload.type = type;
  if (draft.skipRetry === false) {
    payload.skip_retry = false;
  }
  return JSON.stringify(payload);
};

// Build a UI-ready prune rule row with a fresh local id and safe defaults.
const normalizePruneRule = (rule = {}) => ({
  id: nextLocalId(),
  path: typeof rule.path === 'string' ? rule.path : '',
  mode: CONDITION_MODE_VALUES.has(rule.mode) ? rule.mode : 'full',
  value_text: toValueText(rule.value),
  invert: rule.invert === true,
  pass_missing_key: rule.pass_missing_key === true,
});

// Parse a prune_objects value into the structured editor form. Accepts a
// plain type string, or an object with type/logic/recursive plus `where`
// (path→value map) and/or `conditions` (array or map). Presence of any
// advanced field switches the editor out of simple mode.
const parsePruneObjectsDraft = (valueText) => {
  const defaults = {
    simpleMode: true,
    typeText: '',
    logic: 'AND',
    recursive: true,
    rules: [],
  };

  const raw = String(valueText ?? '').trim();
  if (!raw) {
    return defaults;
  }

  try {
    const parsed = JSON.parse(raw);
    if (typeof parsed === 'string') {
      return {
        ...defaults,
        simpleMode: true,
        typeText: parsed.trim(),
      };
    }
    if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
      const rules = [];
      // `where` is shorthand for a set of full-match conditions.
      if (parsed.where && typeof parsed.where === 'object' && !Array.isArray(parsed.where)) {
        Object.entries(parsed.where).forEach(([path, value]) => {
          rules.push(
            normalizePruneRule({
              path,
              mode: 'full',
              value,
            }),
          );
        });
      }
      if (Array.isArray(parsed.conditions)) {
        parsed.conditions.forEach((item) => {
          if (item && typeof item === 'object') {
            rules.push(normalizePruneRule(item));
          }
        });
      } else if (
        parsed.conditions &&
        typeof parsed.conditions === 'object' &&
        !Array.isArray(parsed.conditions)
      ) {
        // Object-form conditions are also treated as full-match per key.
        Object.entries(parsed.conditions).forEach(([path, value]) => {
          rules.push(
            normalizePruneRule({
              path,
              mode: 'full',
              value,
            }),
          );
        });
      }

      const typeText =
        parsed.type === undefined ? '' : String(parsed.type).trim();
      const logic =
        String(parsed.logic || 'AND').toUpperCase() === 'OR' ? 'OR' : 'AND';
      const recursive = parsed.recursive !== false;
      const hasAdvancedFields =
        parsed.logic !== undefined ||
        parsed.recursive !== undefined ||
        parsed.where !== undefined ||
        parsed.conditions !== undefined;

      return {
        ...defaults,
        simpleMode: !hasAdvancedFields,
        typeText,
        logic,
        recursive,
        rules,
      };
    }
    // JSON scalar (number, bool, array element, …): coerce to type text.
    return {
      ...defaults,
      simpleMode: true,
      typeText: String(parsed ?? '').trim(),
    };
  } catch (error) {
    // Non-JSON text is used verbatim as the type string.
    return {
      ...defaults,
      simpleMode: true,
      typeText: raw,
    };
  }
};

// Inverse of parsePruneObjectsDraft: simple mode emits the bare type text;
// advanced mode emits a JSON payload, only including non-default fields.
// An empty advanced draft degrades to {"logic":"AND"} so it stays valid.
const buildPruneObjectsValueText = (draft = {}) => {
  const typeText = String(draft.typeText || '').trim();
  if (draft.simpleMode) {
    return typeText;
  }

  const payload = {};
  if (typeText) {
    payload.type = typeText;
  }
  if (String(draft.logic || 'AND').toUpperCase() === 'OR') {
    payload.logic = 'OR';
  }
  if (draft.recursive === false) {
    payload.recursive = false;
  }

  const conditions = (draft.rules || [])
    .filter((rule) => String(rule.path || '').trim())
    .map((rule) => {
      const conditionPayload = {
        path: String(rule.path || '').trim(),
        mode: CONDITION_MODE_VALUES.has(rule.mode) ? rule.mode : 'full',
      };
      const valueRaw = String(rule.value_text || '').trim();
      if (valueRaw !== '') {
        conditionPayload.value = parseLooseValue(valueRaw);
      }
      if (rule.invert) {
        conditionPayload.invert = true;
      }
      if (rule.pass_missing_key) {
        conditionPayload.pass_missing_key = true;
      }
      return conditionPayload;
    });

  if (conditions.length > 0) {
    payload.conditions = conditions;
  }

  if (!payload.type && !payload.conditions) {
    return JSON.stringify({ logic: 'AND' });
  }
  return JSON.stringify(payload);
};

// Split a sync_fields target spec ("header:Name" / "json:path" / bare path)
// into { type, key }. Bare specs default to the JSON-body type.
const parseSyncTargetSpec = (spec) => {
  const raw = String(spec ?? '').trim();
  if (!raw) return { type: 'json', key: '' };
  const idx = raw.indexOf(':');
  if (idx < 0) return { type: 'json', key: raw };
  const prefix = raw.slice(0, idx).trim().toLowerCase();
  const key = raw.slice(idx + 1).trim();
  if (prefix === 'header') {
    return { type: 'header', key };
  }
  return { type: 'json', key };
};

// Inverse of parseSyncTargetSpec; empty keys serialize to ''.
const buildSyncTargetSpec = (type, key) => {
  const normalizedType = type === 'header' ? 'header' : 'json';
  const normalizedKey = String(key ?? '').trim();
  if (!normalizedKey) return '';
  return `${normalizedType}:${normalizedKey}`;
};

// Build a UI-ready condition row with a fresh local id and safe defaults.
const normalizeCondition = (condition = {}) => ({
  id: nextLocalId(),
  path: typeof condition.path === 'string' ? condition.path : '',
  mode: CONDITION_MODE_VALUES.has(condition.mode) ? condition.mode : 'full',
  value_text: toValueText(condition.value),
  invert: condition.invert === true,
  pass_missing_key: condition.pass_missing_key === true,
});

const createDefaultCondition = () => normalizeCondition({});

// Build a UI-ready operation row from a raw payload operation. Note the
// condition logic defaults to 'OR' here (templates may override to 'AND').
const normalizeOperation = (operation = {}) => ({
  id: nextLocalId(),
  path: typeof operation.path === 'string' ? operation.path : '',
  mode: OPERATION_MODE_VALUES.has(operation.mode) ? operation.mode : 'set',
  value_text: toValueText(operation.value),
  keep_origin: operation.keep_origin === true,
  from: typeof operation.from === 'string' ? operation.from : '',
  to: typeof operation.to === 'string' ? operation.to : '',
  logic: String(operation.logic || 'OR').toUpperCase() === 'AND' ? 'AND' : 'OR',
  conditions: Array.isArray(operation.conditions)
    ? operation.conditions.map(normalizeCondition)
    : [],
});

const createDefaultOperation = () => normalizeOperation({ mode: 'set' });

// One-line list summary for an operation: "N. label · primary field".
// sync_fields prefers from/to; other modes prefer path.
const getOperationSummary = (operation = {}, index = 0) => {
  const mode = operation.mode || 'set';
  const modeLabel = OPERATION_MODE_LABEL_MAP[mode] || mode;
  if (mode === 'sync_fields') {
    const from = String(operation.from || '').trim();
    const to = String(operation.to || '').trim();
    return `${index + 1}. ${modeLabel} · ${from || to || '-'}`;
  }
  const path = String(operation.path || '').trim();
  const from = String(operation.from || '').trim();
  const to = String(operation.to || '').trim();
  return `${index + 1}. ${modeLabel} · ${path || from || to || '-'}`;
};

// Tag color per operation family for the list view (substring dispatch).
const getOperationModeTagColor = (mode = 'set') => {
  if (mode.includes('header')) return 'cyan';
  if (mode.includes('replace') || mode.includes('trim')) return 'violet';
  if (mode.includes('copy') || mode.includes('move')) return 'blue';
  if (mode.includes('error') || mode.includes('prune')) return 'red';
  if (mode.includes('sync')) return 'green';
  return 'grey';
};

// Derive the editor's initial state from the stored override string:
//   - empty        → visual/operations mode with one blank rule,
//   - invalid JSON → raw JSON mode with an error banner,
//   - {operations} → visual rule-set mode,
//   - plain object → visual legacy mode,
//   - anything else (array/scalar) → raw JSON mode.
const parseInitialState = (rawValue) => {
  const text = typeof rawValue === 'string' ? rawValue : '';
  const trimmed = text.trim();
  if (!trimmed) {
    return {
      editMode: 'visual',
      visualMode: 'operations',
      legacyValue: '',
      operations: [createDefaultOperation()],
      jsonText: '',
      jsonError: '',
    };
  }

  if (!verifyJSON(trimmed)) {
    return {
      editMode: 'json',
      visualMode: 'operations',
      legacyValue: '',
      operations: [createDefaultOperation()],
      jsonText: text,
      jsonError: 'JSON 格式不正确',
    };
  }

  const parsed = JSON.parse(trimmed);
  const pretty = JSON.stringify(parsed, null, 2);

  if (
    parsed &&
    typeof parsed === 'object' &&
    !Array.isArray(parsed) &&
    Array.isArray(parsed.operations)
  ) {
    return {
      editMode: 'visual',
      visualMode: 'operations',
      legacyValue: '',
      operations:
        parsed.operations.length > 0
          ?
parsed.operations.map(normalizeOperation) + : [createDefaultOperation()], + jsonText: pretty, + jsonError: '', + }; + } + + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return { + editMode: 'visual', + visualMode: 'legacy', + legacyValue: pretty, + operations: [createDefaultOperation()], + jsonText: pretty, + jsonError: '', + }; + } + + return { + editMode: 'json', + visualMode: 'operations', + legacyValue: '', + operations: [createDefaultOperation()], + jsonText: pretty, + jsonError: '', + }; +}; + +const isOperationBlank = (operation) => { + const hasCondition = (operation.conditions || []).some( + (condition) => + condition.path.trim() || + String(condition.value_text ?? '').trim() || + condition.mode !== 'full' || + condition.invert || + condition.pass_missing_key, + ); + return ( + operation.mode === 'set' && + !operation.path.trim() && + !operation.from.trim() && + !operation.to.trim() && + String(operation.value_text ?? '').trim() === '' && + !operation.keep_origin && + !hasCondition + ); +}; + +const buildConditionPayload = (condition) => { + const path = condition.path.trim(); + if (!path) return null; + const payload = { + path, + mode: condition.mode || 'full', + value: parseLooseValue(condition.value_text), + }; + if (condition.invert) payload.invert = true; + if (condition.pass_missing_key) payload.pass_missing_key = true; + return payload; +}; + +const validateOperations = (operations, t) => { + for (let i = 0; i < operations.length; i++) { + const op = operations[i]; + const mode = op.mode || 'set'; + const meta = MODE_META[mode] || MODE_META.set; + const line = i + 1; + const pathValue = op.path.trim(); + const fromValue = op.from.trim(); + const toValue = op.to.trim(); + + if (meta.path && !pathValue) { + return t('第 {{line}} 条操作缺少目标路径', { line }); + } + if (FROM_REQUIRED_MODES.has(mode) && !fromValue) { + if (!(meta.pathAlias && pathValue)) { + return t('第 {{line}} 条操作缺少来源字段', { line }); + } + } + if 
(TO_REQUIRED_MODES.has(mode) && !toValue) { + if (!(meta.pathAlias && pathValue)) { + return t('第 {{line}} 条操作缺少目标字段', { line }); + } + } + if (meta.from && !fromValue) { + return t('第 {{line}} 条操作缺少来源字段', { line }); + } + if (meta.to && !toValue) { + return t('第 {{line}} 条操作缺少目标字段', { line }); + } + if ( + VALUE_REQUIRED_MODES.has(mode) && + String(op.value_text ?? '').trim() === '' + ) { + return t('第 {{line}} 条操作缺少值', { line }); + } + if (mode === 'return_error') { + const raw = String(op.value_text ?? '').trim(); + if (!raw) { + return t('第 {{line}} 条操作缺少值', { line }); + } + try { + const parsed = JSON.parse(raw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + if (!String(parsed.message || '').trim()) { + return t('第 {{line}} 条 return_error 需要 message 字段', { line }); + } + } + } catch (error) { + // plain string value is allowed + } + } + + if (mode === 'prune_objects') { + const raw = String(op.value_text ?? '').trim(); + if (!raw) { + return t('第 {{line}} 条 prune_objects 缺少条件', { line }); + } + try { + const parsed = JSON.parse(raw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + const hasType = + parsed.type !== undefined && + String(parsed.type).trim() !== ''; + const hasWhere = + parsed.where && + typeof parsed.where === 'object' && + !Array.isArray(parsed.where) && + Object.keys(parsed.where).length > 0; + const hasConditionsArray = + Array.isArray(parsed.conditions) && parsed.conditions.length > 0; + const hasConditionsObject = + parsed.conditions && + typeof parsed.conditions === 'object' && + !Array.isArray(parsed.conditions) && + Object.keys(parsed.conditions).length > 0; + if (!hasType && !hasWhere && !hasConditionsArray && !hasConditionsObject) { + return t('第 {{line}} 条 prune_objects 需要至少一个匹配条件', { + line, + }); + } + } + } catch (error) { + // non-JSON string is treated as type string + } + } + + if (mode === 'pass_headers') { + const raw = String(op.value_text ?? 
'').trim(); + if (!raw) { + return t('第 {{line}} 条请求头透传缺少请求头名称', { line }); + } + const parsed = parseLooseValue(raw); + const headers = parsePassHeaderNames(parsed); + if (headers.length === 0) { + return t('第 {{line}} 条请求头透传格式无效', { line }); + } + } + } + return ''; +}; + +const ParamOverrideEditorModal = ({ visible, value, onSave, onCancel }) => { + const { t } = useTranslation(); + + const [editMode, setEditMode] = useState('visual'); + const [visualMode, setVisualMode] = useState('operations'); + const [legacyValue, setLegacyValue] = useState(''); + const [operations, setOperations] = useState([createDefaultOperation()]); + const [jsonText, setJsonText] = useState(''); + const [jsonError, setJsonError] = useState(''); + const [operationSearch, setOperationSearch] = useState(''); + const [selectedOperationId, setSelectedOperationId] = useState(''); + const [expandedConditionMap, setExpandedConditionMap] = useState({}); + const [templateGroupKey, setTemplateGroupKey] = useState('basic'); + const [templatePresetKey, setTemplatePresetKey] = useState('operations_default'); + const [fieldGuideVisible, setFieldGuideVisible] = useState(false); + const [fieldGuideTarget, setFieldGuideTarget] = useState('path'); + const [fieldGuideKeyword, setFieldGuideKeyword] = useState(''); + + useEffect(() => { + if (!visible) return; + const nextState = parseInitialState(value); + setEditMode(nextState.editMode); + setVisualMode(nextState.visualMode); + setLegacyValue(nextState.legacyValue); + setOperations(nextState.operations); + setJsonText(nextState.jsonText); + setJsonError(nextState.jsonError); + setOperationSearch(''); + setSelectedOperationId(nextState.operations[0]?.id || ''); + setExpandedConditionMap({}); + if (nextState.visualMode === 'legacy') { + setTemplateGroupKey('basic'); + setTemplatePresetKey('legacy_default'); + } else { + setTemplateGroupKey('basic'); + setTemplatePresetKey('operations_default'); + } + setFieldGuideVisible(false); + 
setFieldGuideTarget('path'); + setFieldGuideKeyword(''); + }, [visible, value]); + + useEffect(() => { + if (operations.length === 0) { + setSelectedOperationId(''); + return; + } + if (!operations.some((item) => item.id === selectedOperationId)) { + setSelectedOperationId(operations[0].id); + } + }, [operations, selectedOperationId]); + + const templatePresetOptions = useMemo( + () => + Object.entries(TEMPLATE_PRESET_CONFIG) + .filter(([, config]) => config.group === templateGroupKey) + .map(([value, config]) => ({ + value, + label: config.label, + })), + [templateGroupKey], + ); + + useEffect(() => { + if (templatePresetOptions.length === 0) return; + const exists = templatePresetOptions.some( + (item) => item.value === templatePresetKey, + ); + if (!exists) { + setTemplatePresetKey(templatePresetOptions[0].value); + } + }, [templatePresetKey, templatePresetOptions]); + + const operationCount = useMemo( + () => operations.filter((item) => !isOperationBlank(item)).length, + [operations], + ); + + const filteredOperations = useMemo(() => { + const keyword = operationSearch.trim().toLowerCase(); + if (!keyword) return operations; + return operations.filter((operation) => { + const searchableText = [ + operation.mode, + operation.path, + operation.from, + operation.to, + operation.value_text, + ] + .filter(Boolean) + .join(' ') + .toLowerCase(); + return searchableText.includes(keyword); + }); + }, [operationSearch, operations]); + + const selectedOperation = useMemo( + () => operations.find((operation) => operation.id === selectedOperationId), + [operations, selectedOperationId], + ); + + const selectedOperationIndex = useMemo( + () => + operations.findIndex((operation) => operation.id === selectedOperationId), + [operations, selectedOperationId], + ); + + const returnErrorDraft = useMemo(() => { + if (!selectedOperation || (selectedOperation.mode || '') !== 'return_error') { + return null; + } + return parseReturnErrorDraft(selectedOperation.value_text); + }, 
[selectedOperation]); + + const pruneObjectsDraft = useMemo(() => { + if (!selectedOperation || (selectedOperation.mode || '') !== 'prune_objects') { + return null; + } + return parsePruneObjectsDraft(selectedOperation.value_text); + }, [selectedOperation]); + + const topOperationModes = useMemo(() => { + const counts = operations.reduce((acc, operation) => { + const mode = operation.mode || 'set'; + acc[mode] = (acc[mode] || 0) + 1; + return acc; + }, {}); + return Object.entries(counts) + .sort((a, b) => b[1] - a[1]) + .slice(0, 4); + }, [operations]); + + const buildOperationsJson = useCallback( + (sourceOperations, options = {}) => { + const { validate = true } = options; + const filteredOps = sourceOperations.filter((item) => !isOperationBlank(item)); + if (filteredOps.length === 0) return ''; + + if (validate) { + const message = validateOperations(filteredOps, t); + if (message) { + throw new Error(message); + } + } + + const payloadOps = filteredOps.map((operation) => { + const mode = operation.mode || 'set'; + const meta = MODE_META[mode] || MODE_META.set; + const pathValue = operation.path.trim(); + const fromValue = operation.from.trim(); + const toValue = operation.to.trim(); + const payload = { mode }; + if (meta.path) { + payload.path = pathValue; + } + if (meta.pathOptional && pathValue) { + payload.path = pathValue; + } + if (meta.value) { + payload.value = parseLooseValue(operation.value_text); + } + if (meta.keepOrigin && operation.keep_origin) { + payload.keep_origin = true; + } + if (meta.from) { + payload.from = fromValue; + } + if (!meta.to && operation.to.trim()) { + payload.to = toValue; + } + if (meta.to) { + payload.to = toValue; + } + if (meta.pathAlias) { + if (!payload.from && pathValue) { + payload.from = pathValue; + } + if (!payload.to && pathValue) { + payload.to = pathValue; + } + } + + const conditions = (operation.conditions || []) + .map(buildConditionPayload) + .filter(Boolean); + + if (conditions.length > 0) { + 
payload.conditions = conditions; + payload.logic = operation.logic === 'AND' ? 'AND' : 'OR'; + } + + return payload; + }); + + return JSON.stringify({ operations: payloadOps }, null, 2); + }, + [t], + ); + + const buildVisualJson = useCallback(() => { + if (visualMode === 'legacy') { + const trimmed = legacyValue.trim(); + if (!trimmed) return ''; + if (!verifyJSON(trimmed)) { + throw new Error(t('参数覆盖必须是合法的 JSON 格式!')); + } + const parsed = JSON.parse(trimmed); + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error(t('旧格式必须是 JSON 对象')); + } + return JSON.stringify(parsed, null, 2); + } + return buildOperationsJson(operations, { validate: true }); + }, [buildOperationsJson, legacyValue, operations, t, visualMode]); + + const switchToJsonMode = () => { + if (editMode === 'json') return; + try { + setJsonText(buildVisualJson()); + setJsonError(''); + } catch (error) { + showError(error.message); + if (visualMode === 'legacy') { + setJsonText(legacyValue); + } else { + setJsonText(buildOperationsJson(operations, { validate: false })); + } + setJsonError(error.message || t('参数配置有误')); + } + setEditMode('json'); + }; + + const switchToVisualMode = () => { + if (editMode === 'visual') return; + const trimmed = jsonText.trim(); + if (!trimmed) { + const fallback = createDefaultOperation(); + setVisualMode('operations'); + setOperations([fallback]); + setSelectedOperationId(fallback.id); + setLegacyValue(''); + setJsonError(''); + setEditMode('visual'); + return; + } + if (!verifyJSON(trimmed)) { + showError(t('参数覆盖必须是合法的 JSON 格式!')); + return; + } + const parsed = JSON.parse(trimmed); + if ( + parsed && + typeof parsed === 'object' && + !Array.isArray(parsed) && + Array.isArray(parsed.operations) + ) { + const nextOperations = + parsed.operations.length > 0 + ? 
parsed.operations.map(normalizeOperation) + : [createDefaultOperation()]; + setVisualMode('operations'); + setOperations(nextOperations); + setSelectedOperationId(nextOperations[0]?.id || ''); + setLegacyValue(''); + setJsonError(''); + setEditMode('visual'); + setTemplateGroupKey('basic'); + setTemplatePresetKey('operations_default'); + return; + } + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + const fallback = createDefaultOperation(); + setVisualMode('legacy'); + setLegacyValue(JSON.stringify(parsed, null, 2)); + setOperations([fallback]); + setSelectedOperationId(fallback.id); + setJsonError(''); + setEditMode('visual'); + setTemplateGroupKey('basic'); + setTemplatePresetKey('legacy_default'); + return; + } + showError(t('参数覆盖必须是合法的 JSON 对象')); + }; + + const fillLegacyTemplate = (legacyPayload) => { + const text = JSON.stringify(legacyPayload, null, 2); + const fallback = createDefaultOperation(); + setVisualMode('legacy'); + setLegacyValue(text); + setOperations([fallback]); + setSelectedOperationId(fallback.id); + setExpandedConditionMap({}); + setJsonText(text); + setJsonError(''); + setEditMode('visual'); + }; + + const fillOperationsTemplate = (operationsPayload) => { + const nextOperations = (operationsPayload || []).map(normalizeOperation); + const finalOperations = + nextOperations.length > 0 ? 
nextOperations : [createDefaultOperation()]; + setVisualMode('operations'); + setOperations(finalOperations); + setSelectedOperationId(finalOperations[0]?.id || ''); + setExpandedConditionMap({}); + setJsonText(JSON.stringify({ operations: operationsPayload || [] }, null, 2)); + setJsonError(''); + setEditMode('visual'); + }; + + const appendLegacyTemplate = (legacyPayload) => { + let parsedCurrent = {}; + if (visualMode === 'legacy') { + const trimmed = legacyValue.trim(); + if (trimmed) { + if (!verifyJSON(trimmed)) { + showError(t('当前旧格式 JSON 不合法,无法追加模板')); + return; + } + const parsed = JSON.parse(trimmed); + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + showError(t('当前旧格式不是 JSON 对象,无法追加模板')); + return; + } + parsedCurrent = parsed; + } + } + + const merged = { + ...(legacyPayload || {}), + ...parsedCurrent, + }; + const text = JSON.stringify(merged, null, 2); + const fallback = createDefaultOperation(); + setVisualMode('legacy'); + setLegacyValue(text); + setOperations([fallback]); + setSelectedOperationId(fallback.id); + setExpandedConditionMap({}); + setJsonText(text); + setJsonError(''); + setEditMode('visual'); + }; + + const appendOperationsTemplate = (operationsPayload) => { + const appended = (operationsPayload || []).map(normalizeOperation); + const existing = + visualMode === 'operations' + ? operations.filter((item) => !isOperationBlank(item)) + : []; + const nextOperations = [...existing, ...appended]; + setVisualMode('operations'); + setOperations(nextOperations.length > 0 ? 
nextOperations : appended); + setSelectedOperationId(nextOperations[0]?.id || appended[0]?.id || ''); + setExpandedConditionMap({}); + setLegacyValue(''); + setJsonError(''); + setEditMode('visual'); + setJsonText(''); + }; + + const clearValue = () => { + const fallback = createDefaultOperation(); + setVisualMode('operations'); + setLegacyValue(''); + setOperations([fallback]); + setSelectedOperationId(fallback.id); + setExpandedConditionMap({}); + setJsonText(''); + setJsonError(''); + setTemplateGroupKey('basic'); + setTemplatePresetKey('operations_default'); + }; + + const getSelectedTemplatePreset = () => + TEMPLATE_PRESET_CONFIG[templatePresetKey] || + TEMPLATE_PRESET_CONFIG.operations_default; + + const fillTemplateFromLibrary = () => { + const preset = getSelectedTemplatePreset(); + if (preset.kind === 'legacy') { + fillLegacyTemplate(preset.payload || {}); + return; + } + fillOperationsTemplate(preset.payload?.operations || []); + }; + + const appendTemplateFromLibrary = () => { + const preset = getSelectedTemplatePreset(); + if (preset.kind === 'legacy') { + appendLegacyTemplate(preset.payload || {}); + return; + } + appendOperationsTemplate(preset.payload?.operations || []); + }; + + const resetEditorState = () => { + clearValue(); + setEditMode('visual'); + }; + + const applyBuiltinField = (fieldKey, target = 'path') => { + if (!selectedOperation) { + showError(t('请先选择一条规则')); + return; + } + const mode = selectedOperation.mode || 'set'; + const meta = MODE_META[mode] || MODE_META.set; + if (target === 'path' && (meta.path || meta.pathOptional || meta.pathAlias)) { + updateOperation(selectedOperation.id, { path: fieldKey }); + return; + } + if (target === 'from' && (meta.from || meta.pathAlias || mode === 'sync_fields')) { + updateOperation(selectedOperation.id, { + from: mode === 'sync_fields' ? 
buildSyncTargetSpec('json', fieldKey) : fieldKey, + }); + return; + } + if (target === 'to' && (meta.to || mode === 'sync_fields')) { + updateOperation(selectedOperation.id, { + to: mode === 'sync_fields' ? buildSyncTargetSpec('json', fieldKey) : fieldKey, + }); + return; + } + showError(t('当前规则不支持写入到该位置')); + }; + + const openFieldGuide = (target = 'path') => { + setFieldGuideTarget(target); + setFieldGuideVisible(true); + }; + + const copyBuiltinField = async (fieldKey) => { + const ok = await copy(fieldKey); + if (ok) { + showSuccess(t('已复制字段:{{name}}', { name: fieldKey })); + } else { + showError(t('复制失败')); + } + }; + + const filteredFieldGuideSections = useMemo(() => { + const keyword = fieldGuideKeyword.trim().toLowerCase(); + if (!keyword) { + return BUILTIN_FIELD_SECTIONS; + } + return BUILTIN_FIELD_SECTIONS.map((section) => ({ + ...section, + fields: section.fields.filter((field) => + [field.key, field.label, field.tip] + .filter(Boolean) + .join(' ') + .toLowerCase() + .includes(keyword), + ), + })).filter((section) => section.fields.length > 0); + }, [fieldGuideKeyword]); + + const fieldGuideActionLabel = useMemo(() => { + if (fieldGuideTarget === 'from') return t('填入来源'); + if (fieldGuideTarget === 'to') return t('填入目标'); + return t('填入路径'); + }, [fieldGuideTarget, t]); + + const fieldGuideFieldCount = useMemo( + () => + filteredFieldGuideSections.reduce( + (total, section) => total + section.fields.length, + 0, + ), + [filteredFieldGuideSections], + ); + + const updateOperation = (operationId, patch) => { + setOperations((prev) => + prev.map((item) => + item.id === operationId ? 
{ ...item, ...patch } : item, + ), + ); + }; + + const formatSelectedOperationValueAsJson = useCallback(() => { + if (!selectedOperation) return; + const raw = String(selectedOperation.value_text || '').trim(); + if (!raw) return; + if (!verifyJSON(raw)) { + showError(t('当前值不是合法 JSON,无法格式化')); + return; + } + try { + updateOperation(selectedOperation.id, { + value_text: JSON.stringify(JSON.parse(raw), null, 2), + }); + showSuccess(t('JSON 已格式化')); + } catch (error) { + showError(t('当前值不是合法 JSON,无法格式化')); + } + }, [selectedOperation, t, updateOperation]); + + const updateReturnErrorDraft = (operationId, draftPatch = {}) => { + const current = operations.find((item) => item.id === operationId); + if (!current) return; + const draft = parseReturnErrorDraft(current.value_text); + const nextDraft = { ...draft, ...draftPatch }; + updateOperation(operationId, { + value_text: buildReturnErrorValueText(nextDraft), + }); + }; + + const updatePruneObjectsDraft = (operationId, updater) => { + const current = operations.find((item) => item.id === operationId); + if (!current) return; + const draft = parsePruneObjectsDraft(current.value_text); + const nextDraft = + typeof updater === 'function' + ? updater(draft) + : { ...draft, ...(updater || {}) }; + updateOperation(operationId, { + value_text: buildPruneObjectsValueText(nextDraft), + }); + }; + + const addPruneRule = (operationId) => { + updatePruneObjectsDraft(operationId, (draft) => ({ + ...draft, + simpleMode: false, + rules: [...(draft.rules || []), normalizePruneRule({})], + })); + }; + + const updatePruneRule = (operationId, ruleId, patch) => { + updatePruneObjectsDraft(operationId, (draft) => ({ + ...draft, + rules: (draft.rules || []).map((rule) => + rule.id === ruleId ? 
{ ...rule, ...patch } : rule, + ), + })); + }; + + const removePruneRule = (operationId, ruleId) => { + updatePruneObjectsDraft(operationId, (draft) => ({ + ...draft, + rules: (draft.rules || []).filter((rule) => rule.id !== ruleId), + })); + }; + + const addOperation = () => { + const created = createDefaultOperation(); + setOperations((prev) => [...prev, created]); + setSelectedOperationId(created.id); + }; + + const duplicateOperation = (operationId) => { + let insertedId = ''; + setOperations((prev) => { + const index = prev.findIndex((item) => item.id === operationId); + if (index < 0) return prev; + const source = prev[index]; + const cloned = normalizeOperation({ + path: source.path, + mode: source.mode, + value: parseLooseValue(source.value_text), + keep_origin: source.keep_origin, + from: source.from, + to: source.to, + logic: source.logic, + conditions: (source.conditions || []).map((condition) => ({ + path: condition.path, + mode: condition.mode, + value: parseLooseValue(condition.value_text), + invert: condition.invert, + pass_missing_key: condition.pass_missing_key, + })), + }); + insertedId = cloned.id; + const next = [...prev]; + next.splice(index + 1, 0, cloned); + return next; + }); + if (insertedId) { + setSelectedOperationId(insertedId); + } + }; + + const removeOperation = (operationId) => { + setOperations((prev) => { + if (prev.length <= 1) return [createDefaultOperation()]; + return prev.filter((item) => item.id !== operationId); + }); + setExpandedConditionMap((prev) => { + if (!Object.prototype.hasOwnProperty.call(prev, operationId)) { + return prev; + } + const next = { ...prev }; + delete next[operationId]; + return next; + }); + }; + + const addCondition = (operationId) => { + const createdCondition = createDefaultCondition(); + setOperations((prev) => + prev.map((operation) => + operation.id === operationId + ? 
{ + ...operation, + conditions: [...(operation.conditions || []), createdCondition], + } + : operation, + ), + ); + setExpandedConditionMap((prev) => ({ + ...prev, + [operationId]: [...(prev[operationId] || []), createdCondition.id], + })); + }; + + const updateCondition = (operationId, conditionId, patch) => { + setOperations((prev) => + prev.map((operation) => { + if (operation.id !== operationId) return operation; + return { + ...operation, + conditions: (operation.conditions || []).map((condition) => + condition.id === conditionId + ? { ...condition, ...patch } + : condition, + ), + }; + }), + ); + }; + + const removeCondition = (operationId, conditionId) => { + setOperations((prev) => + prev.map((operation) => { + if (operation.id !== operationId) return operation; + return { + ...operation, + conditions: (operation.conditions || []).filter( + (condition) => condition.id !== conditionId, + ), + }; + }), + ); + setExpandedConditionMap((prev) => ({ + ...prev, + [operationId]: (prev[operationId] || []).filter( + (id) => id !== conditionId, + ), + })); + }; + + const selectedConditionKeys = useMemo( + () => expandedConditionMap[selectedOperationId] || [], + [expandedConditionMap, selectedOperationId], + ); + + const handleConditionCollapseChange = useCallback( + (operationId, activeKeys) => { + const keys = ( + Array.isArray(activeKeys) ? 
activeKeys : [activeKeys] + ).filter(Boolean); + setExpandedConditionMap((prev) => ({ + ...prev, + [operationId]: keys, + })); + }, + [], + ); + + const expandAllSelectedConditions = useCallback(() => { + if (!selectedOperationId || !selectedOperation) return; + setExpandedConditionMap((prev) => ({ + ...prev, + [selectedOperationId]: (selectedOperation.conditions || []).map( + (condition) => condition.id, + ), + })); + }, [selectedOperation, selectedOperationId]); + + const collapseAllSelectedConditions = useCallback(() => { + if (!selectedOperationId) return; + setExpandedConditionMap((prev) => ({ + ...prev, + [selectedOperationId]: [], + })); + }, [selectedOperationId]); + + const handleJsonChange = (nextValue) => { + setJsonText(nextValue); + const trimmed = String(nextValue || '').trim(); + if (!trimmed) { + setJsonError(''); + return; + } + if (!verifyJSON(trimmed)) { + setJsonError(t('JSON格式错误')); + return; + } + setJsonError(''); + }; + + const formatJson = () => { + const trimmed = jsonText.trim(); + if (!trimmed) return; + if (!verifyJSON(trimmed)) { + showError(t('参数覆盖必须是合法的 JSON 格式!')); + return; + } + setJsonText(JSON.stringify(JSON.parse(trimmed), null, 2)); + setJsonError(''); + }; + + const visualValidationError = useMemo(() => { + if (editMode !== 'visual') { + return ''; + } + try { + buildVisualJson(); + return ''; + } catch (error) { + return error?.message || t('参数配置有误'); + } + }, [buildVisualJson, editMode, t]); + + const handleSave = () => { + try { + let result = ''; + if (editMode === 'json') { + const trimmed = jsonText.trim(); + if (!trimmed) { + result = ''; + } else { + if (!verifyJSON(trimmed)) { + throw new Error(t('参数覆盖必须是合法的 JSON 格式!')); + } + result = JSON.stringify(JSON.parse(trimmed), null, 2); + } + } else { + result = buildVisualJson(); + } + onSave?.(result); + } catch (error) { + showError(error.message); + } + }; + + return ( + <> + + + +
+ + {t('编辑方式')} + + + {t('模板')} + + setTemplatePresetKey(nextValue || 'operations_default') + } + style={{ width: 260 }} + /> + + + + + openFieldGuide('path')} + > + {t('字段速查')} + +
+
+ + {editMode === 'visual' ? ( +
+ {visualMode === 'legacy' ? ( + + {t('旧格式(JSON 对象)')} +