Files
new-api/relay/common/override.go
2025-08-16 17:16:46 +08:00

397 lines
11 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package common
import (
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"strings"
)
type ConditionOperation struct {
Path string `json:"path"` // JSON路径
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
Value interface{} `json:"value"` // 匹配的值
Invert bool `json:"invert"` // 反选功能true表示取反结果
PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
}
type ParamOperation struct {
Path string `json:"path"`
Mode string `json:"mode"` // delete, set, move, prepend, append
Value interface{} `json:"value"`
KeepOrigin bool `json:"keep_origin"`
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
}
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
if len(paramOverride) == 0 {
return jsonData, nil
}
// 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok {
// 使用新方法
result, err := applyOperations(string(jsonData), operations)
return []byte(result), err
}
// 直接使用旧方法
return applyOperationsLegacy(jsonData, paramOverride)
}
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
// 检查是否包含 "operations" 字段
if opsValue, exists := paramOverride["operations"]; exists {
if opsSlice, ok := opsValue.([]interface{}); ok {
var operations []ParamOperation
for _, op := range opsSlice {
if opMap, ok := op.(map[string]interface{}); ok {
operation := ParamOperation{}
// 断言必要字段
if path, ok := opMap["path"].(string); ok {
operation.Path = path
}
if mode, ok := opMap["mode"].(string); ok {
operation.Mode = mode
} else {
return nil, false // mode 是必需的
}
// 可选字段
if value, exists := opMap["value"]; exists {
operation.Value = value
}
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
operation.KeepOrigin = keepOrigin
}
if from, ok := opMap["from"].(string); ok {
operation.From = from
}
if to, ok := opMap["to"].(string); ok {
operation.To = to
}
if logic, ok := opMap["logic"].(string); ok {
operation.Logic = logic
} else {
operation.Logic = "OR" // 默认为OR
}
// 解析条件
if conditions, exists := opMap["conditions"]; exists {
if condSlice, ok := conditions.([]interface{}); ok {
for _, cond := range condSlice {
if condMap, ok := cond.(map[string]interface{}); ok {
condition := ConditionOperation{}
if path, ok := condMap["path"].(string); ok {
condition.Path = path
}
if mode, ok := condMap["mode"].(string); ok {
condition.Mode = mode
}
if value, ok := condMap["value"]; ok {
condition.Value = value
}
if invert, ok := condMap["invert"].(bool); ok {
condition.Invert = invert
}
if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
condition.PassMissingKey = passMissingKey
}
operation.Conditions = append(operation.Conditions, condition)
}
}
}
}
operations = append(operations, operation)
} else {
return nil, false
}
}
return operations, true
}
}
return nil, false
}
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 {
return true, nil // 没有条件,直接通过
}
results := make([]bool, len(conditions))
for i, condition := range conditions {
result, err := checkSingleCondition(jsonStr, condition)
if err != nil {
return false, err
}
results[i] = result
}
if strings.ToUpper(logic) == "AND" {
for _, result := range results {
if !result {
return false, nil
}
}
return true, nil
} else {
for _, result := range results {
if result {
return true, nil
}
}
return false, nil
}
}
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
value := gjson.Get(jsonStr, condition.Path)
if !value.Exists() {
if condition.PassMissingKey {
return true, nil
}
return false, nil
}
// 利用gjson的类型解析
targetBytes, err := json.Marshal(condition.Value)
if err != nil {
return false, fmt.Errorf("failed to marshal condition value: %v", err)
}
targetValue := gjson.ParseBytes(targetBytes)
result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
if err != nil {
return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
}
if condition.Invert {
result = !result
}
return result, nil
}
// compareGjsonValues 直接比较两个gjson.Result支持所有比较模式
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
switch mode {
case "full":
return compareEqual(jsonValue, targetValue)
case "prefix":
return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
case "suffix":
return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
case "contains":
return strings.Contains(jsonValue.String(), targetValue.String()), nil
case "gt":
return compareNumeric(jsonValue, targetValue, "gt")
case "gte":
return compareNumeric(jsonValue, targetValue, "gte")
case "lt":
return compareNumeric(jsonValue, targetValue, "lt")
case "lte":
return compareNumeric(jsonValue, targetValue, "lte")
default:
return false, fmt.Errorf("unsupported comparison mode: %s", mode)
}
}
func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
// 对布尔值特殊处理
if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
(targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
return jsonValue.Bool() == targetValue.Bool(), nil
}
// 如果类型不同,报错
if jsonValue.Type != targetValue.Type {
return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
}
switch jsonValue.Type {
case gjson.True, gjson.False:
return jsonValue.Bool() == targetValue.Bool(), nil
case gjson.Number:
return jsonValue.Num == targetValue.Num, nil
case gjson.String:
return jsonValue.String() == targetValue.String(), nil
default:
return jsonValue.String() == targetValue.String(), nil
}
}
func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
// 只有数字类型才支持数值比较
if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
}
jsonNum := jsonValue.Num
targetNum := targetValue.Num
switch operator {
case "gt":
return jsonNum > targetNum, nil
case "gte":
return jsonNum >= targetNum, nil
case "lt":
return jsonNum < targetNum, nil
case "lte":
return jsonNum <= targetNum, nil
default:
return false, fmt.Errorf("unsupported numeric operator: %s", operator)
}
}
// applyOperationsLegacy 原参数覆盖方法
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
reqMap := make(map[string]interface{})
err := json.Unmarshal(jsonData, &reqMap)
if err != nil {
return nil, err
}
for key, value := range paramOverride {
reqMap[key] = value
}
return json.Marshal(reqMap)
}
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
result := jsonStr
for _, op := range operations {
// 检查条件是否满足
ok, err := checkConditions(result, op.Conditions, op.Logic)
if err != nil {
return "", err
}
if !ok {
continue // 条件不满足,跳过当前操作
}
switch op.Mode {
case "delete":
result, err = sjson.Delete(result, op.Path)
case "set":
if op.KeepOrigin && gjson.Get(result, op.Path).Exists() {
continue
}
result, err = sjson.Set(result, op.Path, op.Value)
case "move":
result, err = moveValue(result, op.From, op.To)
case "prepend":
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true)
case "append":
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false)
default:
return "", fmt.Errorf("unknown operation: %s", op.Mode)
}
if err != nil {
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
}
}
return result, nil
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
}
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
if err != nil {
return "", err
}
return sjson.Delete(result, fromPath)
}
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
switch {
case current.IsArray():
return modifyArray(jsonStr, path, value, isPrepend)
case current.Type == gjson.String:
return modifyString(jsonStr, path, value, isPrepend)
case current.Type == gjson.JSON:
return mergeObjects(jsonStr, path, value, keepOrigin)
}
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
var newArray []interface{}
// 添加新值
addValue := func() {
if arr, ok := value.([]interface{}); ok {
newArray = append(newArray, arr...)
} else {
newArray = append(newArray, value)
}
}
// 添加原值
addOriginal := func() {
current.ForEach(func(_, val gjson.Result) bool {
newArray = append(newArray, val.Value())
return true
})
}
if isPrepend {
addValue()
addOriginal()
} else {
addOriginal()
addValue()
}
return sjson.Set(jsonStr, path, newArray)
}
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
valueStr := fmt.Sprintf("%v", value)
var newStr string
if isPrepend {
newStr = valueStr + current.String()
} else {
newStr = current.String() + valueStr
}
return sjson.Set(jsonStr, path, newStr)
}
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
current := gjson.Get(jsonStr, path)
var currentMap, newMap map[string]interface{}
// 解析当前值
if err := json.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
return "", err
}
// 解析新值
switch v := value.(type) {
case map[string]interface{}:
newMap = v
default:
jsonBytes, _ := json.Marshal(v)
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
return "", err
}
}
// 合并
result := make(map[string]interface{})
for k, v := range currentMap {
result[k] = v
}
for k, v := range newMap {
if !keepOrigin || result[k] == nil {
result[k] = v
}
}
return sjson.Set(jsonStr, path, result)
}