From f7b284ad73dc69265ef642652a48446836999d31 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 19:08:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=94=99=E8=AF=AF=E5=86=85=E5=AE=B9?= =?UTF-8?q?=E8=84=B1=E6=95=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 95 +++++++++++++++++++++++++++++++++++++++++++++ controller/relay.go | 3 +- types/error.go | 29 +++++++++++--- 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/common/str.go b/common/str.go index 88b58c72..f5399eab 100644 --- a/common/str.go +++ b/common/str.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "encoding/json" "math/rand" + "net/url" + "regexp" "strconv" + "strings" "unsafe" ) @@ -95,3 +98,95 @@ func GetJsonString(data any) string { b, _ := json.Marshal(data) return string(b) } + +// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// Example: +// http://example.com -> http://***.com +// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** +// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** +// 192.168.1.1 -> ***.***.***.*** +func MaskSensitiveInfo(str string) string { + // Mask URLs + urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) + str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + return urlStr + } + + host := u.Host + if host == "" { + return urlStr + } + + // Split host by dots + parts := strings.Split(host, ".") + if len(parts) < 2 { + // If less than 2 parts, just mask the whole host + return u.Scheme + "://***" + u.Path + } + + // Keep the TLD (Top Level Domain) and mask the rest + var maskedHost string + if len(parts) == 2 { + // example.com -> ***.com + maskedHost = "***." + parts[len(parts)-1] + } else { + // Handle cases like sub.domain.co.uk or api.example.com + // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + maskedHost = "***." + secondLastPart + "." + lastPart + } else { + // Regular TLD like .com, .org + maskedHost = "***." + lastPart + } + } + + result := u.Scheme + "://" + maskedHost + + // Mask path + if u.Path != "" && u.Path != "/" { + pathParts := strings.Split(strings.Trim(u.Path, "/"), "/") + maskedPathParts := make([]string, len(pathParts)) + for i := range pathParts { + if pathParts[i] != "" { + maskedPathParts[i] = "***" + } + } + if len(maskedPathParts) > 0 { + result += "/" + strings.Join(maskedPathParts, "/") + } + } else if u.Path == "/" { + result += "/" + } + + // Mask query parameters + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + // If can't parse query, just mask the whole query string + result += "?***" + } else { + maskedParams := make([]string, 0, len(values)) + for key := range values { + maskedParams = append(maskedParams, key+"=***") + } + if len(maskedParams) > 0 { + result += "?" + strings.Join(maskedParams, "&") + } + } + } + + return result + }) + + // Mask IP addresses + ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) + str = ipPattern.ReplaceAllString(str, "***.***.***.***") + + return str +} diff --git a/controller/relay.go b/controller/relay.go index d4b5fd18..01081d3d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,8 +62,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { other["channel_id"] = channelId other["channel_name"] = c.GetString("channel_name") other["channel_type"] = c.GetInt("channel_type") - - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other) + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) } return err diff --git a/types/error.go b/types/error.go index c94bd001..2a8105c7 100644 --- a/types/error.go +++ b/types/error.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "one-api/common" "strings" ) @@ -107,19 +108,30 @@ func (e *NewAPIError) Error() string { return e.Err.Error() } +func (e *NewAPIError) MaskSensitiveError() string { + if e == nil { + return "" + } + if e.Err == nil { + return string(e.errorCode) + } + return common.MaskSensitiveInfo(e.Err.Error()) +} + func (e *NewAPIError) SetMessage(message string) { e.Err = errors.New(message) } func (e *NewAPIError) ToOpenAIError() OpenAIError { + var result OpenAIError switch e.errorType { case ErrorTypeOpenAIError: if openAIError, ok := e.RelayError.(OpenAIError); ok { - return openAIError + result = openAIError } case ErrorTypeClaudeError: if claudeError, ok := e.RelayError.(ClaudeError); ok { - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: claudeError.Type, Param: "", @@ -127,30 +139,35 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { } } } - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: string(e.errorType), Param: "", Code: e.errorCode, } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func (e *NewAPIError) ToClaudeError() ClaudeError { + var result ClaudeError switch e.errorType { case ErrorTypeOpenAIError: openAIError := e.RelayError.(OpenAIError) - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: fmt.Sprintf("%v", openAIError.Code), } case ErrorTypeClaudeError: - return e.RelayError.(ClaudeError) + result = e.RelayError.(ClaudeError) default: - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: string(e.errorType), } } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func NewError(err error, errorCode ErrorCode) *NewAPIError {