fix: merge runtime and channel header overrides, skip missing source headers

This commit is contained in:
Seefs
2026-02-25 16:12:34 +08:00
parent 58fcd9cbca
commit 3034fb8899
5 changed files with 234 additions and 13 deletions

View File

@@ -22,6 +22,7 @@ const (
paramOverrideContextRequestHeadersRaw = "request_headers_raw"
paramOverrideContextHeaderOverride = "header_override"
paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
paramOverrideContextHeaderOverrideDeleted = "header_override_deleted_normalized"
)
var errSourceHeaderNotFound = errors.New("source header does not exist")
@@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
return info.ChannelMeta.HeadersOverride
}
func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
if len(source) == 0 {
return map[string]interface{}{}
}
target := make(map[string]interface{}, len(source))
for key, value := range source {
target[key] = value
}
return target
}
func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
key = strings.TrimSpace(key)
if key == "" {
return
}
for existing := range target {
if strings.EqualFold(strings.TrimSpace(existing), key) {
delete(target, existing)
}
}
target[key] = value
}
func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
if len(deleted) == 0 {
return false
}
normalized := normalizeHeaderContextKey(headerName)
if normalized == "" {
return false
}
return deleted[normalized]
}
func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
merged := make(map[string]interface{}, len(base)+len(runtime))
for key, value := range base {
if isHeaderDeletedByRuntime(key, deleted) {
continue
}
setHeaderOverrideEntry(merged, key, value)
}
for key, value := range runtime {
setHeaderOverrideEntry(merged, key, value)
}
return merged
}
func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
if len(source) == 0 {
return map[string]bool{}
}
target := make(map[string]bool, len(source))
for key, value := range source {
if !value {
continue
}
normalized := normalizeHeaderContextKey(key)
if normalized == "" {
continue
}
target[normalized] = true
}
return target
}
func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
if info == nil {
return map[string]interface{}{}
}
base := getHeaderOverrideMap(info)
if !info.UseRuntimeHeadersOverride {
return cloneHeaderOverrideMap(base)
}
return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
}
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
// 检查是否包含 "operations" 字段
if opsValue, exists := paramOverride["operations"]; exists {
@@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
targetHeader = strings.TrimSpace(op.Path)
}
err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if errors.Is(err, errSourceHeaderNotFound) {
err = nil
}
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
@@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
targetHeader = strings.TrimSpace(op.Path)
}
err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
if errors.Is(err, errSourceHeaderNotFound) {
err = nil
}
if err == nil {
contextJSON, err = marshalContextJSON(context)
}
@@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
rawHeaders[headerName] = headerValue
normalizedHeaderName := normalizeHeaderContextKey(headerName)
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
normalizedHeaders[normalizedHeaderName] = headerValue
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
delete(deletedHeaders, normalizedHeaderName)
}
return nil
}
@@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st
}
normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
normalizedHeaderName := normalizeHeaderContextKey(headerName)
delete(normalizedHeaders, normalizedHeaderName)
if normalizedHeaderName != "" {
deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
deletedHeaders[normalizedHeaderName] = true
}
return nil
}
@@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
return item.Key, item.Value
})
info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
info.UseRuntimeHeadersOverride = true
}
func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
if !exists {
return nil
}
deletedMap, ok := deletedRaw.(map[string]interface{})
if !ok || len(deletedMap) == 0 {
return nil
}
entries := lo.Entries(deletedMap)
sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
if keep, ok := item.Value.(bool); ok && !keep {
return "", false
}
normalized := normalizeHeaderContextKey(item.Key)
if normalized == "" {
return "", false
}
return normalized, true
})
if len(sanitized) == 0 {
return nil
}
keys := lo.Uniq(sanitized)
return lo.SliceToMap(keys, func(item string) (string, bool) {
return item, true
})
}
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
@@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
headerOverrideSource := getHeaderOverrideMap(info)
if info.UseRuntimeHeadersOverride {
headerOverrideSource = info.RuntimeHeadersOverride
}
headerOverrideSource := GetEffectiveHeaderOverride(info)
rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
return item, true
})
ctx["retry_index"] = info.RetryIndex
ctx["is_retry"] = info.RetryIndex > 0