Files
new-api-oiss/relay/common/override.go

1002 lines
28 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package common
import (
"errors"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// negativeIndexRegexp matches a path segment of the form ".-N" (a dot followed
// by a minus sign and digits) and captures the "-N" part.
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
// ConditionOperation describes one condition that is evaluated against JSON
// data to decide whether an operation applies.
type ConditionOperation struct {
Path string `json:"path"` // JSON path of the value to test
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
Value interface{} `json:"value"` // value to compare against
Invert bool `json:"invert"` // when true, the comparison result is negated
PassMissingKey bool `json:"pass_missing_key"` // result of the condition when the JSON key cannot be found
}
// ParamOperation describes one edit applied to the request JSON, optionally
// guarded by a list of conditions.
type ParamOperation struct {
Path string `json:"path"`
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
Value interface{} `json:"value"`
KeepOrigin bool `json:"keep_origin"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Conditions []ConditionOperation `json:"conditions,omitempty"` // list of conditions
Logic string `json:"logic,omitempty"` // AND, OR (defaults to OR)
}
// ParamOverrideReturnError is the error value produced by a "return_error"
// operation. It carries the message, HTTP status code, error code, error type
// and retry hint configured for that operation.
type ParamOverrideReturnError struct {
Message string
StatusCode int
Code string
Type string
SkipRetry bool
}
// Error implements the error interface. It returns the configured message, or
// a generic fallback text when the receiver is nil or the message is empty.
func (e *ParamOverrideReturnError) Error() string {
	const fallback = "param override return error"
	if e != nil && e.Message != "" {
		return e.Message
	}
	return fallback
}
// AsParamOverrideReturnError reports whether err (or any error it wraps) is a
// *ParamOverrideReturnError and, if so, returns it. A nil err yields
// (nil, false).
func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
	var found *ParamOverrideReturnError
	if err == nil || !errors.As(err, &found) {
		return nil, false
	}
	return found, true
}
// NewAPIErrorFromParamOverride converts a ParamOverrideReturnError into a
// *types.NewAPIError.
//
// A nil input produces a skip-retry error with the "param override invalid"
// error code. Otherwise blank fields are replaced by defaults: a status code
// outside 100..511 becomes 400, a blank code becomes the invalid-request code,
// a blank type becomes "invalid_request_error" and a blank message becomes a
// generic "blocked" text. The skip-retry option is forwarded only when the
// source error asks for it.
func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
	if err == nil {
		return types.NewError(
			errors.New("param override return error is nil"),
			types.ErrorCodeChannelParamOverrideInvalid,
			types.ErrOptionWithSkipRetry(),
		)
	}

	status := err.StatusCode
	if status < http.StatusContinue || status > http.StatusNetworkAuthenticationRequired {
		status = http.StatusBadRequest
	}

	// Code and type are only replaced when blank; non-blank values are kept
	// verbatim (not trimmed). The message, in contrast, is always trimmed.
	code, kind := err.Code, err.Type
	if strings.TrimSpace(code) == "" {
		code = string(types.ErrorCodeInvalidRequest)
	}
	if strings.TrimSpace(kind) == "" {
		kind = "invalid_request_error"
	}
	text := strings.TrimSpace(err.Message)
	if text == "" {
		text = "request blocked by param override"
	}

	options := make([]types.NewAPIErrorOptions, 0, 1)
	if err.SkipRetry {
		options = append(options, types.ErrOptionWithSkipRetry())
	}

	openAIErr := types.OpenAIError{
		Message: text,
		Type:    kind,
		Code:    code,
	}
	return types.WithOpenAIError(openAIErr, status, options...)
}
// ApplyParamOverride applies paramOverride to the request body jsonData.
//
// An empty override returns jsonData untouched. When the override is in the
// "operations" format it is executed operation by operation, with
// conditionContext available to conditions; otherwise every top-level key of
// the override is written onto the request object (legacy behaviour).
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
	if len(paramOverride) == 0 {
		return jsonData, nil
	}

	operations, isOperationFormat := tryParseOperations(paramOverride)
	if !isOperationFormat {
		// Not the operations format: fall back to the legacy key overwrite.
		return applyOperationsLegacy(jsonData, paramOverride)
	}

	updated, err := applyOperations(string(jsonData), operations, conditionContext)
	return []byte(updated), err
}
// tryParseOperations tries to interpret paramOverride as the operations
// format, i.e. a map whose "operations" key holds a list of operation objects.
//
// It returns false when the key is missing or is not a list, when a list entry
// is not an object, or when an entry has no string "mode". All other fields are
// optional; "logic" defaults to "OR". Condition entries that are not objects
// are skipped.
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
	rawOps, isList := paramOverride["operations"].([]interface{})
	if !isList {
		return nil, false
	}

	var operations []ParamOperation
	for _, rawOp := range rawOps {
		fields, isObject := rawOp.(map[string]interface{})
		if !isObject {
			return nil, false
		}
		// mode is the only mandatory field.
		mode, hasMode := fields["mode"].(string)
		if !hasMode {
			return nil, false
		}

		operation := ParamOperation{
			Mode:  mode,
			Value: fields["value"],
			Logic: "OR", // default when "logic" is absent or not a string
		}
		// A failed assertion yields the zero value, which is the field default.
		operation.Path, _ = fields["path"].(string)
		operation.KeepOrigin, _ = fields["keep_origin"].(bool)
		operation.From, _ = fields["from"].(string)
		operation.To, _ = fields["to"].(string)
		if logic, ok := fields["logic"].(string); ok {
			operation.Logic = logic
		}

		rawConditions, _ := fields["conditions"].([]interface{})
		for _, rawCondition := range rawConditions {
			condFields, ok := rawCondition.(map[string]interface{})
			if !ok {
				continue
			}
			condition := ConditionOperation{Value: condFields["value"]}
			condition.Path, _ = condFields["path"].(string)
			condition.Mode, _ = condFields["mode"].(string)
			condition.Invert, _ = condFields["invert"].(bool)
			condition.PassMissingKey, _ = condFields["pass_missing_key"].(bool)
			operation.Conditions = append(operation.Conditions, condition)
		}

		operations = append(operations, operation)
	}
	return operations, true
}
// checkConditions evaluates every condition against jsonStr (falling back to
// contextJSON inside checkSingleCondition) and combines the results.
//
// An empty list passes. With logic "AND" (case-insensitive) all conditions
// must hold; any other logic value is treated as OR. Every condition is
// evaluated before the results are combined, so an evaluation error in any of
// them is returned even if an earlier condition already decided the outcome.
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
	total := len(conditions)
	if total == 0 {
		return true, nil
	}

	passed := 0
	for _, condition := range conditions {
		ok, err := checkSingleCondition(jsonStr, contextJSON, condition)
		if err != nil {
			return false, err
		}
		if ok {
			passed++
		}
	}

	if strings.ToUpper(logic) == "AND" {
		return passed == total, nil
	}
	return passed > 0, nil
}
// checkSingleCondition evaluates one condition.
//
// The value is looked up in jsonStr (after resolving negative array indexes)
// and, if not found there and contextJSON is non-empty, in contextJSON using
// the unmodified condition path. When the value is found nowhere the result is
// condition.PassMissingKey; Invert is not applied on that path. Otherwise the
// found value is compared with condition.Value using the lower-cased mode and
// the result is negated when Invert is set.
//
// Errors from marshalling the condition value or from the comparison are
// wrapped with %w so callers can inspect the underlying error.
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
	// Resolve negative indexes against the request JSON.
	path := processNegativeIndex(jsonStr, condition.Path)
	value := gjson.Get(jsonStr, path)
	if !value.Exists() && contextJSON != "" {
		value = gjson.Get(contextJSON, condition.Path)
	}
	if !value.Exists() {
		return condition.PassMissingKey, nil
	}

	// Round-trip the expected value through JSON so gjson's type detection can
	// be used for both sides of the comparison.
	targetBytes, err := common.Marshal(condition.Value)
	if err != nil {
		return false, fmt.Errorf("failed to marshal condition value: %w", err)
	}
	targetValue := gjson.ParseBytes(targetBytes)

	result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
	if err != nil {
		return false, fmt.Errorf("comparison failed for path %s: %w", condition.Path, err)
	}
	if condition.Invert {
		result = !result
	}
	return result, nil
}
// processNegativeIndex rewrites negative array indexes in path (segments of
// the form ".-N") into the equivalent non-negative index, based on the length
// of the array found in jsonStr. For example, with a three-element "messages"
// array, "messages.-1.content" becomes "messages.2.content".
//
// Segments are resolved left to right and each one is looked up using the
// already-resolved prefix of the path, so several negative indexes in one path
// (e.g. "a.-1.b.-1") each refer to their own array. A segment is left
// unchanged when the prefix does not point to an array, when the index is out
// of range for that array, or when the number cannot be parsed.
func processNegativeIndex(jsonStr string, path string) string {
	// Each location is [matchStart, matchEnd, groupStart, groupEnd].
	locations := negativeIndexRegexp.FindAllStringSubmatchIndex(path, -1)
	if len(locations) == 0 {
		return path
	}

	var resolved strings.Builder
	resolved.Grow(len(path))
	consumed := 0
	for _, loc := range locations {
		// Copy the literal text between the previous match and this one; the
		// builder now holds the (resolved) path of the array being indexed.
		resolved.WriteString(path[consumed:loc[0]])
		consumed = loc[1]

		segment := path[loc[0]:loc[1]]
		if index, err := strconv.Atoi(path[loc[2]:loc[3]]); err == nil {
			array := gjson.Get(jsonStr, resolved.String())
			if array.IsArray() {
				length := len(array.Array())
				if actual := length + index; actual >= 0 && actual < length {
					segment = "." + strconv.Itoa(actual)
				}
			}
		}
		resolved.WriteString(segment)
	}
	resolved.WriteString(path[consumed:])
	return resolved.String()
}
// compareGjsonValues compares two gjson results using the given mode.
//
// "full" delegates to compareEqual; "gt", "gte", "lt" and "lte" delegate to
// compareNumeric; "prefix", "suffix" and "contains" operate on the string form
// of both values. Any other mode is an error.
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
	var textMatch func(s, part string) bool

	switch mode {
	case "full":
		return compareEqual(jsonValue, targetValue)
	case "gt", "gte", "lt", "lte":
		return compareNumeric(jsonValue, targetValue, mode)
	case "prefix":
		textMatch = strings.HasPrefix
	case "suffix":
		textMatch = strings.HasSuffix
	case "contains":
		textMatch = strings.Contains
	default:
		return false, fmt.Errorf("unsupported comparison mode: %s", mode)
	}

	return textMatch(jsonValue.String(), targetValue.String()), nil
}
// compareEqual reports whether two gjson results are equal.
//
// Rules, in order:
//   - if either side is null, the result is true only when both are null;
//   - two booleans are compared by value (True and False are distinct gjson
//     types, so they are handled before the type check);
//   - differing types are an error;
//   - numbers are compared numerically;
//   - everything else (strings, and objects/arrays via their string form) is
//     compared as text.
func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
	isBool := func(r gjson.Result) bool {
		return r.Type == gjson.True || r.Type == gjson.False
	}

	switch {
	case jsonValue.Type == gjson.Null || targetValue.Type == gjson.Null:
		return jsonValue.Type == gjson.Null && targetValue.Type == gjson.Null, nil
	case isBool(jsonValue) && isBool(targetValue):
		return jsonValue.Bool() == targetValue.Bool(), nil
	case jsonValue.Type != targetValue.Type:
		return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
	case jsonValue.Type == gjson.Number:
		return jsonValue.Num == targetValue.Num, nil
	default:
		return jsonValue.String() == targetValue.String(), nil
	}
}
// compareNumeric applies the ordering operator ("gt", "gte", "lt" or "lte") to
// two gjson numbers. It returns an error when either side is not a number or
// when the operator is unknown.
func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
	bothNumbers := jsonValue.Type == gjson.Number && targetValue.Type == gjson.Number
	if !bothNumbers {
		return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
	}

	left, right := jsonValue.Num, targetValue.Num
	var outcome bool
	switch operator {
	case "gt":
		outcome = left > right
	case "gte":
		outcome = left >= right
	case "lt":
		outcome = left < right
	case "lte":
		outcome = left <= right
	default:
		return false, fmt.Errorf("unsupported numeric operator: %s", operator)
	}
	return outcome, nil
}
// applyOperationsLegacy is the original override method: it decodes jsonData
// into an object, overwrites (or adds) every top-level key found in
// paramOverride, and re-encodes the result.
//
// It returns an error when jsonData cannot be decoded into an object.
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
	var reqMap map[string]interface{}
	if err := common.Unmarshal(jsonData, &reqMap); err != nil {
		return nil, err
	}
	// A JSON "null" body decodes to a nil map (encoding/json semantics, which
	// common.Unmarshal is assumed to follow); writing to it would panic.
	if reqMap == nil {
		reqMap = make(map[string]interface{}, len(paramOverride))
	}
	for key, value := range paramOverride {
		reqMap[key] = value
	}
	return common.Marshal(reqMap)
}
// applyOperations runs operations in order against jsonStr and returns the
// resulting JSON text.
//
// conditionContext, when non-empty, is serialised once and made available to
// condition checks and to prune_objects. For each operation the conditions are
// evaluated against the current (already modified) JSON; operations whose
// conditions do not hold are skipped. Negative array indexes in path, from and
// to are resolved against the current JSON before use.
//
// Processing stops at the first error. A "return_error" operation whose
// conditions hold always stops processing: it returns a
// *ParamOverrideReturnError, or a parse error if its value is invalid.
// Failures of the other operations are wrapped as "operation <mode> failed".
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
	var contextJSON string
	if len(conditionContext) > 0 {
		ctxBytes, err := common.Marshal(conditionContext)
		if err != nil {
			return "", fmt.Errorf("failed to marshal condition context: %w", err)
		}
		contextJSON = string(ctxBytes)
	}

	result := jsonStr
	for _, op := range operations {
		// Skip the operation when its conditions are not satisfied.
		ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
		if err != nil {
			return "", err
		}
		if !ok {
			continue
		}

		// Resolve negative indexes in the target path.
		opPath := processNegativeIndex(result, op.Path)

		switch op.Mode {
		case "delete":
			result, err = sjson.Delete(result, opPath)
		case "set":
			// keep_origin: leave an existing value untouched.
			if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
				continue
			}
			result, err = sjson.Set(result, opPath, op.Value)
		case "move":
			// Same validation as "copy": both endpoints are mandatory.
			if op.From == "" || op.To == "" {
				return "", fmt.Errorf("move from/to is required")
			}
			opFrom := processNegativeIndex(result, op.From)
			opTo := processNegativeIndex(result, op.To)
			result, err = moveValue(result, opFrom, opTo)
		case "copy":
			if op.From == "" || op.To == "" {
				return "", fmt.Errorf("copy from/to is required")
			}
			opFrom := processNegativeIndex(result, op.From)
			opTo := processNegativeIndex(result, op.To)
			result, err = copyValue(result, opFrom, opTo)
		case "prepend":
			result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
		case "append":
			result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
		case "trim_prefix":
			result, err = trimStringValue(result, opPath, op.Value, true)
		case "trim_suffix":
			result, err = trimStringValue(result, opPath, op.Value, false)
		case "ensure_prefix":
			result, err = ensureStringAffix(result, opPath, op.Value, true)
		case "ensure_suffix":
			result, err = ensureStringAffix(result, opPath, op.Value, false)
		case "trim_space":
			result, err = transformStringValue(result, opPath, strings.TrimSpace)
		case "to_lower":
			result, err = transformStringValue(result, opPath, strings.ToLower)
		case "to_upper":
			result, err = transformStringValue(result, opPath, strings.ToUpper)
		case "replace":
			result, err = replaceStringValue(result, opPath, op.From, op.To)
		case "regex_replace":
			result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
		case "return_error":
			returnErr, parseErr := parseParamOverrideReturnError(op.Value)
			if parseErr != nil {
				return "", parseErr
			}
			return "", returnErr
		case "prune_objects":
			result, err = pruneObjects(result, opPath, contextJSON, op.Value)
		default:
			return "", fmt.Errorf("unknown operation: %s", op.Mode)
		}

		if err != nil {
			return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
		}
	}
	return result, nil
}
// parseParamOverrideReturnError builds a ParamOverrideReturnError from the
// value of a "return_error" operation.
//
// value is either a string (used as the message) or an object with the
// optional keys "message" (or "msg" as a fallback), "code", "type",
// "skip_retry" and "status_code" (or "status" when "status_code" is absent).
// Defaults are status 400, the invalid-request code, type
// "invalid_request_error" and skip_retry true. A missing or null "code" keeps
// the default code.
//
// It returns an error when value is nil or of another type, when the message
// is empty after trimming, when the status is not an integer, or when the
// status lies outside 100..511.
func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
	result := &ParamOverrideReturnError{
		StatusCode: http.StatusBadRequest,
		Code:       string(types.ErrorCodeInvalidRequest),
		Type:       "invalid_request_error",
		SkipRetry:  true,
	}

	switch raw := value.(type) {
	case nil:
		return nil, fmt.Errorf("return_error value is required")
	case string:
		result.Message = strings.TrimSpace(raw)
	case map[string]interface{}:
		if message, ok := raw["message"].(string); ok {
			result.Message = strings.TrimSpace(message)
		}
		if result.Message == "" {
			if message, ok := raw["msg"].(string); ok {
				result.Message = strings.TrimSpace(message)
			}
		}
		// A JSON null code would otherwise be formatted as the literal text
		// "<nil>", so it is treated like a missing key.
		if code, exists := raw["code"]; exists && code != nil {
			codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
			if codeStr != "" {
				result.Code = codeStr
			}
		}
		if errType, ok := raw["type"].(string); ok {
			errType = strings.TrimSpace(errType)
			if errType != "" {
				result.Type = errType
			}
		}
		if skipRetry, ok := raw["skip_retry"].(bool); ok {
			result.SkipRetry = skipRetry
		}
		if statusCodeRaw, exists := raw["status_code"]; exists {
			statusCode, ok := parseOverrideInt(statusCodeRaw)
			if !ok {
				return nil, fmt.Errorf("return_error status_code must be an integer")
			}
			result.StatusCode = statusCode
		} else if statusRaw, exists := raw["status"]; exists {
			statusCode, ok := parseOverrideInt(statusRaw)
			if !ok {
				return nil, fmt.Errorf("return_error status must be an integer")
			}
			result.StatusCode = statusCode
		}
	default:
		return nil, fmt.Errorf("return_error value must be string or object")
	}

	if result.Message == "" {
		return nil, fmt.Errorf("return_error message is required")
	}
	if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
		return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
	}
	return result, nil
}
// parseOverrideInt extracts an int from a decoded configuration value.
//
// It accepts an int as-is and a float64 only when it holds a whole number;
// every other type (including nil) yields (0, false).
func parseOverrideInt(v interface{}) (int, bool) {
	if n, isInt := v.(int); isInt {
		return n, true
	}

	f, isFloat := v.(float64)
	if !isFloat {
		return 0, false
	}

	// Reject values with a fractional part: the round trip through int must
	// reproduce the original float.
	n := int(f)
	if float64(n) != f {
		return 0, false
	}
	return n, true
}
// moveValue copies the value at fromPath to toPath and then deletes fromPath.
//
// The value is transferred as raw JSON, so large integers, number formatting
// and object key order are preserved exactly instead of being re-encoded
// through Go values. It returns an error when fromPath does not exist or when
// the set/delete fails.
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
	sourceValue := gjson.Get(jsonStr, fromPath)
	if !sourceValue.Exists() {
		return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
	}

	var (
		result string
		err    error
	)
	if sourceValue.Raw != "" {
		result, err = sjson.SetRaw(jsonStr, toPath, sourceValue.Raw)
	} else {
		// No raw text available: fall back to re-encoding the decoded value.
		result, err = sjson.Set(jsonStr, toPath, sourceValue.Value())
	}
	if err != nil {
		return "", err
	}
	return sjson.Delete(result, fromPath)
}
// copyValue writes the value found at fromPath to toPath, leaving fromPath in
// place.
//
// The value is transferred as raw JSON, so large integers, number formatting
// and object key order are preserved exactly instead of being re-encoded
// through Go values. It returns an error when fromPath does not exist.
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
	sourceValue := gjson.Get(jsonStr, fromPath)
	if !sourceValue.Exists() {
		return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
	}
	if sourceValue.Raw == "" {
		// No raw text available: fall back to re-encoding the decoded value.
		return sjson.Set(jsonStr, toPath, sourceValue.Value())
	}
	return sjson.SetRaw(jsonStr, toPath, sourceValue.Raw)
}
// modifyValue implements prepend/append by dispatching on the type of the
// value currently stored at path: arrays get elements added, strings get text
// concatenated, and other JSON containers (objects) are merged with value.
// keepOrigin only affects the object merge. Any other type, including a
// missing path, is an error.
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
	current := gjson.Get(jsonStr, path)

	// Arrays must be tested first: they also report the gjson.JSON type.
	if current.IsArray() {
		return modifyArray(jsonStr, path, value, isPrepend)
	}
	if current.Type == gjson.String {
		return modifyString(jsonStr, path, value, isPrepend)
	}
	if current.Type == gjson.JSON {
		return mergeObjects(jsonStr, path, value, keepOrigin)
	}
	return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
// modifyArray adds value to the array stored at path, in front of the existing
// elements when isPrepend is true and after them otherwise.
//
// When value is itself a list its elements are added individually; any other
// value is added as a single element. The resulting slice is always non-nil so
// that combining an empty array with an empty list still writes "[]" rather
// than "null".
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
	current := gjson.Get(jsonStr, path)

	// Normalise the new value into a list of elements to add.
	incoming, isList := value.([]interface{})
	if !isList {
		incoming = []interface{}{value}
	}

	newArray := make([]interface{}, 0, len(incoming))
	appendOriginal := func() {
		current.ForEach(func(_, val gjson.Result) bool {
			newArray = append(newArray, val.Value())
			return true
		})
	}

	if isPrepend {
		newArray = append(newArray, incoming...)
		appendOriginal()
	} else {
		appendOriginal()
		newArray = append(newArray, incoming...)
	}
	return sjson.Set(jsonStr, path, newArray)
}
// modifyString concatenates the %v-formatted value with the string stored at
// path, placing it in front when isPrepend is true and behind otherwise.
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
	existing := gjson.Get(jsonStr, path).String()
	addition := fmt.Sprintf("%v", value)

	combined := existing + addition
	if isPrepend {
		combined = addition + existing
	}
	return sjson.Set(jsonStr, path, combined)
}
// trimStringValue removes the %v-formatted value from the start (isPrefix) or
// the end of the string stored at path. It returns an error when the target is
// not a string or when value is nil.
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
	current := gjson.Get(jsonStr, path)
	switch {
	case current.Type != gjson.String:
		return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
	case value == nil:
		return jsonStr, fmt.Errorf("trim value is required")
	}

	trim := strings.TrimSuffix
	if isPrefix {
		trim = strings.TrimPrefix
	}
	affix := fmt.Sprintf("%v", value)
	return sjson.Set(jsonStr, path, trim(current.String(), affix))
}
// ensureStringAffix makes sure the string stored at path starts (isPrefix) or
// ends with the %v-formatted value, adding it only when it is not already
// there. It returns an error when the target is not a string or when value is
// nil or formats to an empty string.
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
	current := gjson.Get(jsonStr, path)
	if current.Type != gjson.String {
		return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
	}

	affix := fmt.Sprintf("%v", value)
	if value == nil || affix == "" {
		return jsonStr, fmt.Errorf("ensure value is required")
	}

	existing := current.String()
	var (
		alreadyPresent bool
		updated        string
	)
	if isPrefix {
		alreadyPresent = strings.HasPrefix(existing, affix)
		updated = affix + existing
	} else {
		alreadyPresent = strings.HasSuffix(existing, affix)
		updated = existing + affix
	}

	if alreadyPresent {
		return jsonStr, nil
	}
	return sjson.Set(jsonStr, path, updated)
}
// transformStringValue replaces the string stored at path with the result of
// transform applied to it. Non-string targets are an error.
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
	current := gjson.Get(jsonStr, path)
	if current.Type == gjson.String {
		converted := transform(current.String())
		return sjson.Set(jsonStr, path, converted)
	}
	return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
// replaceStringValue replaces every occurrence of from with to inside the
// string stored at path. It returns an error when the target is not a string
// or when from is empty.
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
	current := gjson.Get(jsonStr, path)
	switch {
	case current.Type != gjson.String:
		return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
	case from == "":
		return jsonStr, fmt.Errorf("replace from is required")
	}

	replaced := strings.ReplaceAll(current.String(), from, to)
	return sjson.Set(jsonStr, path, replaced)
}
// regexReplaceStringValue replaces every match of the regular expression
// pattern in the string stored at path with replacement (which may use the
// expansion syntax of regexp.ReplaceAllString). It returns an error when the
// target is not a string, when pattern is empty, or when pattern does not
// compile.
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
	current := gjson.Get(jsonStr, path)
	switch {
	case current.Type != gjson.String:
		return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
	case pattern == "":
		return jsonStr, fmt.Errorf("regex pattern is required")
	}

	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return jsonStr, err
	}
	rewritten := compiled.ReplaceAllString(current.String(), replacement)
	return sjson.Set(jsonStr, path, rewritten)
}
// pruneObjectsOptions holds the parsed settings of a "prune_objects"
// operation.
type pruneObjectsOptions struct {
conditions []ConditionOperation // conditions an object is checked against to decide whether it is removed
logic string // how the conditions are combined; parsePruneObjectsOptions defaults it to "AND"
recursive bool // whether kept objects have their children processed too; parsePruneObjectsOptions defaults it to true
}
// pruneObjects removes objects matching the prune options from the JSON value
// at path, or from the whole document when path is empty.
//
// A path that does not exist leaves jsonStr unchanged. The selected node
// itself is never removed (it is treated as the root); only objects nested
// below it can be dropped. The cleaned node is re-encoded and written back at
// path, or returned as the whole document for an empty path.
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
	options, err := parsePruneObjectsOptions(value)
	if err != nil {
		return "", err
	}

	wholeDocument := path == ""

	// Decode the node that pruning starts from.
	var node interface{}
	if wholeDocument {
		if err := common.Unmarshal([]byte(jsonStr), &node); err != nil {
			return "", err
		}
	} else {
		target := gjson.Get(jsonStr, path)
		if !target.Exists() {
			return jsonStr, nil
		}
		if target.Type != gjson.JSON {
			// Scalars carry no nested objects; use the decoded value directly.
			node = target.Value()
		} else if err := common.Unmarshal([]byte(target.Raw), &node); err != nil {
			return "", err
		}
	}

	cleaned, _, err := pruneObjectsNode(node, options, contextJSON, true)
	if err != nil {
		return "", err
	}
	encoded, err := common.Marshal(cleaned)
	if err != nil {
		return "", err
	}

	if wholeDocument {
		return string(encoded), nil
	}
	return sjson.SetRaw(jsonStr, path, string(encoded))
}
// parsePruneObjectsOptions parses the value of a "prune_objects" operation.
//
// A string is shorthand for "remove objects whose type field equals this
// string". An object may contain "logic", "recursive", "conditions" (a list of
// condition objects), "where" (a map of path to exact value) and "type"
// (exact match on the type field); all resulting conditions are collected in
// that order. Defaults are logic "AND" and recursive true.
//
// It returns an error when value is nil, blank, of another type, malformed, or
// yields no conditions at all.
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
	opts := pruneObjectsOptions{logic: "AND", recursive: true}

	// exactMatch builds a "full" comparison on the given path.
	exactMatch := func(path string, expected interface{}) ConditionOperation {
		return ConditionOperation{Path: path, Mode: "full", Value: expected}
	}

	if value == nil {
		return opts, fmt.Errorf("prune_objects value is required")
	}

	if text, isString := value.(string); isString {
		wanted := strings.TrimSpace(text)
		if wanted == "" {
			return opts, fmt.Errorf("prune_objects value is required")
		}
		opts.conditions = []ConditionOperation{exactMatch("type", wanted)}
		return opts, nil
	}

	settings, isObject := value.(map[string]interface{})
	if !isObject {
		return opts, fmt.Errorf("prune_objects value must be string or object")
	}

	if logic, ok := settings["logic"].(string); ok && strings.TrimSpace(logic) != "" {
		opts.logic = logic
	}
	if recursive, ok := settings["recursive"].(bool); ok {
		opts.recursive = recursive
	}

	if rawConditions, present := settings["conditions"]; present {
		parsed, err := parseConditionOperations(rawConditions)
		if err != nil {
			return opts, err
		}
		opts.conditions = append(opts.conditions, parsed...)
	}

	if rawWhere, present := settings["where"]; present {
		where, ok := rawWhere.(map[string]interface{})
		if !ok {
			return opts, fmt.Errorf("prune_objects where must be object")
		}
		for field, expected := range where {
			field = strings.TrimSpace(field)
			if field != "" {
				opts.conditions = append(opts.conditions, exactMatch(field, expected))
			}
		}
	}

	if wantedType, present := settings["type"]; present {
		opts.conditions = append(opts.conditions, exactMatch("type", wantedType))
	}

	if len(opts.conditions) == 0 {
		return opts, fmt.Errorf("prune_objects conditions are required")
	}
	return opts, nil
}
// parseConditionOperations converts a decoded JSON list into condition
// values. Every entry must be an object with non-blank string "path" and
// "mode"; "value", "invert" and "pass_missing_key" are optional. Any violation
// yields an error and no conditions.
func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
	entries, isList := raw.([]interface{})
	if !isList {
		return nil, fmt.Errorf("conditions must be an array")
	}

	parsed := make([]ConditionOperation, 0, len(entries))
	for _, entry := range entries {
		fields, isObject := entry.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("condition must be object")
		}

		path, _ := fields["path"].(string)
		mode, _ := fields["mode"].(string)
		if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
			return nil, fmt.Errorf("condition path/mode is required")
		}

		// Failed assertions leave the flags at their false default.
		invert, _ := fields["invert"].(bool)
		passMissingKey, _ := fields["pass_missing_key"].(bool)

		parsed = append(parsed, ConditionOperation{
			Path:           path,
			Mode:           mode,
			Value:          fields["value"],
			Invert:         invert,
			PassMissingKey: passMissingKey,
		})
	}
	return parsed, nil
}
// pruneObjectsNode walks a decoded JSON node and removes objects that match
// the prune conditions.
//
// It returns the (possibly modified) node, whether the caller should drop it,
// and any evaluation error. Lists always have each element processed and are
// never dropped themselves. An object that matches is dropped unless it is the
// root; a kept object has its children processed only when options.recursive
// is set, and is modified in place. Scalars are returned unchanged.
func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
	if list, isList := node.([]interface{}); isList {
		kept := make([]interface{}, 0, len(list))
		for i := range list {
			element, dropped, err := pruneObjectsNode(list[i], options, contextJSON, false)
			if err != nil {
				return nil, false, err
			}
			if !dropped {
				kept = append(kept, element)
			}
		}
		return kept, false, nil
	}

	object, isObject := node.(map[string]interface{})
	if !isObject {
		return node, false, nil
	}

	matched, err := shouldPruneObject(object, options, contextJSON)
	if err != nil {
		return nil, false, err
	}
	switch {
	case matched && !isRoot:
		return nil, true, nil
	case !options.recursive:
		return object, false, nil
	}

	// Deleting from / assigning to existing keys while ranging is permitted.
	for name, member := range object {
		updated, dropped, err := pruneObjectsNode(member, options, contextJSON, false)
		if err != nil {
			return nil, false, err
		}
		if dropped {
			delete(object, name)
		} else {
			object[name] = updated
		}
	}
	return object, false, nil
}
// shouldPruneObject reports whether the object satisfies the prune conditions.
// The object is serialised and the conditions are evaluated against that JSON,
// with contextJSON as the fallback lookup source.
func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
	encoded, err := common.Marshal(node)
	if err == nil {
		return checkConditions(string(encoded), contextJSON, options.conditions, options.logic)
	}
	return false, err
}
// mergeObjects merges value into the JSON object stored at path and writes the
// merged object back.
//
// value is used directly when it is already a map; otherwise it is round-
// tripped through JSON and must decode to an object. With keepOrigin set, a
// key from value is only written when the existing object has no value for it
// or that value is null; without it, keys from value always win.
//
// It returns an error when the current value or the new value cannot be
// decoded as an object, or when the new value cannot be marshalled.
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
	current := gjson.Get(jsonStr, path)

	// Decode the current value.
	var currentMap map[string]interface{}
	if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
		return "", err
	}

	// Decode the new value.
	newMap, isMap := value.(map[string]interface{})
	if !isMap {
		jsonBytes, err := common.Marshal(value)
		if err != nil {
			return "", fmt.Errorf("failed to marshal merge value: %w", err)
		}
		if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
			return "", err
		}
	}

	// Merge into a fresh map so neither input is mutated.
	result := make(map[string]interface{}, len(currentMap)+len(newMap))
	for k, v := range currentMap {
		result[k] = v
	}
	for k, v := range newMap {
		if !keepOrigin || result[k] == nil {
			result[k] = v
		}
	}
	return sjson.Set(jsonStr, path, result)
}
// BuildParamOverrideContext builds the context map that ApplyParamOverride
// conditions can read. It returns nil for a nil info.
//
// Keys written:
//   - model / upstream_model: the channel's upstream model name, when set;
//     model falls back to the original model name otherwise.
//   - original_model: the originally requested model name, when set.
//   - request_path: the request URL path, when set.
//   - retry_index, is_retry and retry{index,is_retry}: retry state.
//   - last_error (object) and last_error_status_code / last_error_message /
//     last_error_code / last_error_type: details of the last error, only when
//     one is recorded.
//   - is_channel_test: whether this is a channel test request.
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
	if info == nil {
		return nil
	}

	ctx := make(map[string]interface{})

	if meta := info.ChannelMeta; meta != nil && meta.UpstreamModelName != "" {
		ctx["model"] = meta.UpstreamModelName
		ctx["upstream_model"] = meta.UpstreamModelName
	}
	if original := info.OriginModelName; original != "" {
		ctx["original_model"] = original
		// Only used as "model" when no upstream model was recorded above.
		if _, hasModel := ctx["model"]; !hasModel {
			ctx["model"] = original
		}
	}
	if info.RequestURLPath != "" {
		ctx["request_path"] = info.RequestURLPath
	}

	isRetry := info.RetryIndex > 0
	ctx["retry_index"] = info.RetryIndex
	ctx["is_retry"] = isRetry
	ctx["retry"] = map[string]interface{}{
		"index":    info.RetryIndex,
		"is_retry": isRetry,
	}

	if lastErr := info.LastError; lastErr != nil {
		code := string(lastErr.GetErrorCode())
		errorType := string(lastErr.GetErrorType())
		message := lastErr.Error()

		ctx["last_error"] = map[string]interface{}{
			"status_code": lastErr.StatusCode,
			"message":     message,
			"code":        code,
			"error_code":  code,
			"type":        errorType,
			"error_type":  errorType,
			"skip_retry":  types.IsSkipRetryError(lastErr),
		}
		// Flat aliases of the same information.
		ctx["last_error_status_code"] = lastErr.StatusCode
		ctx["last_error_message"] = message
		ctx["last_error_code"] = code
		ctx["last_error_type"] = errorType
	}

	ctx["is_channel_test"] = info.IsChannelTest
	return ctx
}