refactor(override): simplify header overrides to a lowercase single map

This commit is contained in:
Seefs
2026-02-25 17:24:18 +08:00
parent 3034fb8899
commit a955d4102d
6 changed files with 260 additions and 387 deletions

View File

@@ -18,11 +18,8 @@ import (
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
const (
paramOverrideContextRequestHeaders = "request_headers"
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
paramOverrideContextHeaderOverride = "header_override"
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
paramOverrideContextRequestHeaders = "request_headers"
paramOverrideContextHeaderOverride = "header_override"
)
var errSourceHeaderNotFound = errors.New("source header does not exist")
@@ -161,141 +158,118 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
return info.ChannelMeta.HeadersOverride
}
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
func sanitizeHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
if len(source) == 0 {
return map[string]interface{}{}
}
target := make(map[string]interface{}, len(source))
for key, value := range source {
target[key] = value
normalizedKey := normalizeHeaderContextKey(key)
if normalizedKey == "" {
continue
}
normalizedValue := strings.TrimSpace(fmt.Sprintf("%v", value))
if normalizedValue == "" {
if isHeaderPassthroughRuleKeyForOverride(normalizedKey) {
target[normalizedKey] = ""
}
continue
}
target[normalizedKey] = normalizedValue
}
return target
}
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
key = strings.TrimSpace(key)
func isHeaderPassthroughRuleKeyForOverride(key string) bool {
key = strings.TrimSpace(strings.ToLower(key))
if key == "" {
return
}
for existing := range target {
if strings.EqualFold(strings.TrimSpace(existing), key) {
delete(target, existing)
}
}
target[key] = value
}
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
if len(deleted) == 0 {
return false
}
normalized := normalizeHeaderContextKey(headerName)
if normalized == "" {
return false
if key == "*" {
return true
}
return deleted[normalized]
}
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
merged := make(map[string]interface{}, len(base)+len(runtime))
for key, value := range base {
if isHeaderDeletedByRuntime(key, deleted) {
continue
}
setHeaderOverrideEntry(merged, key, value)
}
for key, value := range runtime {
setHeaderOverrideEntry(merged, key, value)
}
return merged
}
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
if len(source) == 0 {
return map[string]bool{}
}
target := make(map[string]bool, len(source))
for key, value := range source {
if !value {
continue
}
normalized := normalizeHeaderContextKey(key)
if normalized == "" {
continue
}
target[normalized] = true
}
return target
return strings.HasPrefix(key, "re:") || strings.HasPrefix(key, "regex:")
}
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
if info == nil {
return map[string]interface{}{}
}
base := getHeaderOverrideMap(info)
if !info.UseRuntimeHeadersOverride {
return cloneHeaderOverrideMap(base)
if info.UseRuntimeHeadersOverride {
return sanitizeHeaderOverrideMap(info.RuntimeHeadersOverride)
}
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
return sanitizeHeaderOverrideMap(getHeaderOverrideMap(info))
}
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
// 检查是否包含 "operations" 字段
if opsValue, exists := paramOverride["operations"]; exists {
if opsSlice, ok := opsValue.([]interface{}); ok {
var operations []ParamOperation
for _, op := range opsSlice {
if opMap, ok := op.(map[string]interface{}); ok {
operation := ParamOperation{}
// 断言必要字段
if path, ok := opMap["path"].(string); ok {
operation.Path = path
}
if mode, ok := opMap["mode"].(string); ok {
operation.Mode = mode
} else {
return nil, false // mode 是必需的
}
// 可选字段
if value, exists := opMap["value"]; exists {
operation.Value = value
}
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
operation.KeepOrigin = keepOrigin
}
if from, ok := opMap["from"].(string); ok {
operation.From = from
}
if to, ok := opMap["to"].(string); ok {
operation.To = to
}
if logic, ok := opMap["logic"].(string); ok {
operation.Logic = logic
} else {
operation.Logic = "OR" // 默认为OR
}
// 解析条件
if conditions, exists := opMap["conditions"]; exists {
parsedConditions, err := parseConditionOperations(conditions)
if err != nil {
return nil, false
}
operation.Conditions = append(operation.Conditions, parsedConditions...)
}
operations = append(operations, operation)
} else {
return nil, false
}
}
return operations, true
}
opsValue, exists := paramOverride["operations"]
if !exists {
return nil, false
}
return nil, false
var opMaps []map[string]interface{}
switch ops := opsValue.(type) {
case []interface{}:
opMaps = make([]map[string]interface{}, 0, len(ops))
for _, op := range ops {
opMap, ok := op.(map[string]interface{})
if !ok {
return nil, false
}
opMaps = append(opMaps, opMap)
}
case []map[string]interface{}:
opMaps = ops
default:
return nil, false
}
operations := make([]ParamOperation, 0, len(opMaps))
for _, opMap := range opMaps {
operation := ParamOperation{}
// 断言必要字段
if path, ok := opMap["path"].(string); ok {
operation.Path = path
}
if mode, ok := opMap["mode"].(string); ok {
operation.Mode = mode
} else {
return nil, false // mode 是必需的
}
// 可选字段
if value, exists := opMap["value"]; exists {
operation.Value = value
}
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
operation.KeepOrigin = keepOrigin
}
if from, ok := opMap["from"].(string); ok {
operation.From = from
}
if to, ok := opMap["to"].(string); ok {
operation.To = to
}
if logic, ok := opMap["logic"].(string); ok {
operation.Logic = logic
} else {
operation.Logic = "OR" // 默认为OR
}
// 解析条件
if conditions, exists := opMap["conditions"]; exists {
parsedConditions, err := parseConditionOperations(conditions)
if err != nil {
return nil, false
}
operation.Conditions = append(operation.Conditions, parsedConditions...)
}
operations = append(operations, operation)
}
return operations, true
}
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
@@ -712,15 +686,10 @@ func marshalContextJSON(context map[string]interface{}) (string, error) {
}
func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
headerName = strings.TrimSpace(headerName)
headerName = normalizeHeaderContextKey(headerName)
if headerName == "" {
return fmt.Errorf("header name is required")
}
if keepOrigin {
if _, exists := getHeaderValueFromContext(context, headerName); exists {
return nil
}
}
if value == nil {
return fmt.Errorf("header value is required")
}
@@ -730,21 +699,21 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
}
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
rawHeaders[headerName] = headerValue
normalizedHeaderName := normalizeHeaderContextKey(headerName)
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
normalizedHeaders[normalizedHeaderName] = headerValue
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
delete(deletedHeaders, normalizedHeaderName)
if keepOrigin {
if existing, ok := rawHeaders[headerName]; ok {
existingValue := strings.TrimSpace(fmt.Sprintf("%v", existing))
if existingValue != "" {
return nil
}
}
}
rawHeaders[headerName] = headerValue
return nil
}
func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
fromHeader = strings.TrimSpace(fromHeader)
toHeader = strings.TrimSpace(toHeader)
fromHeader = normalizeHeaderContextKey(fromHeader)
toHeader = normalizeHeaderContextKey(toHeader)
if fromHeader == "" || toHeader == "" {
return fmt.Errorf("copy_header from/to is required")
}
@@ -756,8 +725,8 @@ func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader st
}
func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
fromHeader = strings.TrimSpace(fromHeader)
toHeader = strings.TrimSpace(toHeader)
fromHeader = normalizeHeaderContextKey(fromHeader)
toHeader = normalizeHeaderContextKey(toHeader)
if fromHeader == "" || toHeader == "" {
return fmt.Errorf("move_header from/to is required")
}
@@ -771,31 +740,19 @@ func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader st
}
func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
headerName = strings.TrimSpace(headerName)
headerName = normalizeHeaderContextKey(headerName)
if headerName == "" {
return fmt.Errorf("header name is required")
}
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
for key := range rawHeaders {
if strings.EqualFold(strings.TrimSpace(key), headerName) {
delete(rawHeaders, key)
}
}
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
normalizedHeaderName := normalizeHeaderContextKey(headerName)
delete(normalizedHeaders, normalizedHeaderName)
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
deletedHeaders[normalizedHeaderName] = true
}
delete(rawHeaders, headerName)
return nil
}
func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
normalizeNames := func(values []string) []string {
names := lo.FilterMap(values, func(item string, _ int) (string, bool) {
headerName := strings.TrimSpace(item)
headerName := normalizeHeaderContextKey(item)
if headerName == "" {
return "", false
}
@@ -825,7 +782,20 @@ func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
return names, nil
case []interface{}:
names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) {
headerName := strings.TrimSpace(fmt.Sprintf("%v", item))
headerName := normalizeHeaderContextKey(fmt.Sprintf("%v", item))
if headerName == "" {
return "", false
}
return headerName, true
})
names = lo.Uniq(names)
if len(names) == 0 {
return nil, fmt.Errorf("pass_headers value is invalid")
}
return names, nil
case []string:
names := lo.FilterMap(raw, func(item string, _ int) (string, bool) {
headerName := normalizeHeaderContextKey(item)
if headerName == "" {
return "", false
}
@@ -994,76 +964,29 @@ func ensureMapKeyInContext(context map[string]interface{}, key string) map[strin
}
func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
headerName = strings.TrimSpace(headerName)
headerName = normalizeHeaderContextKey(headerName)
if headerName == "" {
return "", false
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok {
return value, true
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok {
return value, true
}
normalizedName := normalizeHeaderContextKey(headerName)
if normalizedName == "" {
return "", false
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok {
return value, true
}
if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok {
return value, true
for _, key := range []string{paramOverrideContextHeaderOverride, paramOverrideContextRequestHeaders} {
source := ensureMapKeyInContext(context, key)
raw, ok := source[headerName]
if !ok {
continue
}
value := strings.TrimSpace(fmt.Sprintf("%v", raw))
if value != "" {
return value, true
}
}
return "", false
}
func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) {
if len(source) == 0 {
return "", false
}
entries := lo.Entries(source)
entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool {
return strings.EqualFold(strings.TrimSpace(item.Key), key)
})
if !ok {
return "", false
}
value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value))
if value == "" {
return "", false
}
return value, true
}
func normalizeHeaderContextKey(key string) string {
key = strings.TrimSpace(strings.ToLower(key))
if key == "" {
return ""
}
var b strings.Builder
b.Grow(len(key))
previousUnderscore := false
for _, r := range key {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
previousUnderscore = false
case r >= '0' && r <= '9':
b.WriteRune(r)
previousUnderscore = false
default:
if !previousUnderscore {
b.WriteByte('_')
previousUnderscore = true
}
}
}
result := strings.Trim(b.String(), "_")
return result
return strings.TrimSpace(strings.ToLower(key))
}
func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
func buildRequestHeadersContext(headers map[string]string) map[string]interface{} {
if len(headers) == 0 {
return map[string]interface{}{}
}
@@ -1081,54 +1004,6 @@ func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
})
}
func buildRawHeaders(headers map[string]string) map[string]interface{} {
if len(headers) == 0 {
return map[string]interface{}{}
}
entries := lo.Entries(headers)
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
key := strings.TrimSpace(item.Key)
value := strings.TrimSpace(item.Value)
if key == "" || value == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: key, Value: value}, true
})
return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
}
func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) {
if len(headers) == 0 {
return map[string]interface{}{}, map[string]interface{}{}
}
entries := lo.Entries(headers)
rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) {
key := strings.TrimSpace(item.Key)
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
if key == "" || value == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: key, Value: value}, true
})
raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
normalized := normalizeHeaderContextKey(item.Key)
if normalized == "" {
return lo.Entry[string, string]{}, false
}
return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true
})
normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
return item.Key, item.Value
})
return raw, normalized
}
func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
if info == nil || context == nil {
return
@@ -1141,55 +1016,10 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
if !ok {
return
}
entries := lo.Entries(rawMap)
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) {
key := strings.TrimSpace(item.Key)
if key == "" {
return lo.Entry[string, interface{}]{}, false
}
value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
if value == "" {
return lo.Entry[string, interface{}]{}, false
}
return lo.Entry[string, interface{}]{Key: key, Value: value}, true
})
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
return item.Key, item.Value
})
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
info.RuntimeHeadersOverride = sanitizeHeaderOverrideMap(rawMap)
info.UseRuntimeHeadersOverride = true
}
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
if !exists {
return nil
}
deletedMap, ok := deletedRaw.(map[string]interface{})
if !ok || len(deletedMap) == 0 {
return nil
}
entries := lo.Entries(deletedMap)
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
if keep, ok := item.Value.(bool); ok && !keep {
return "", false
}
normalized := normalizeHeaderContextKey(item.Key)
if normalized == "" {
return "", false
}
return normalized, true
})
if len(sanitized) == 0 {
return nil
}
keys := lo.Uniq(sanitized)
return lo.SliceToMap(keys, func(item string) (string, bool) {
return item, true
})
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
@@ -1635,16 +1465,10 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
}
}
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
ctx[paramOverrideContextRequestHeaders] = buildRequestHeadersContext(info.RequestHeaders)
headerOverrideSource := GetEffectiveHeaderOverride(info)
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
return item, true
})
ctx[paramOverrideContextHeaderOverride] = sanitizeHeaderOverrideMap(headerOverrideSource)
ctx["retry_index"] = info.RetryIndex
ctx["is_retry"] = info.RetryIndex > 0