From 5fe1ce89ec5f554f416d9326e91baf6397c50538 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 15 Aug 2025 13:20:36 +0800 Subject: [PATCH] refactor: improve request type validation and enhance sensitive information masking --- common/str.go | 89 ++++++++++++++-------------- controller/relay.go | 7 ++- logger/logger.go | 9 ++- relay/channel/xunfei/relay-xunfei.go | 6 +- relay/claude_handler.go | 2 +- relay/common/relay_info.go | 38 ++++++------ relay/embedding_handler.go | 2 +- relay/gemini_handler.go | 2 +- relay/helper/common.go | 3 + service/sensitive.go | 22 +------ types/error.go | 9 --- 11 files changed, 87 insertions(+), 102 deletions(-) diff --git a/common/str.go b/common/str.go index a769b8e4..511a0a39 100644 --- a/common/str.go +++ b/common/str.go @@ -117,6 +117,48 @@ func MaskEmail(email string) string { return "***@" + email[atIndex+1:] } +// maskHostTail returns the tail parts of a domain/host that should be preserved. +// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD. +func maskHostTail(parts []string) []string { + if len(parts) < 2 { + return parts + } + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + return []string{secondLastPart, lastPart} + } + return []string{lastPart} +} + +// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail. +// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk +func maskHostForURL(host string) string { + parts := strings.Split(host, ".") + if len(parts) < 2 { + return "***" + } + tail := maskHostTail(parts) + return "***." + strings.Join(tail, ".") +} + +// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***. +// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk +func maskHostForPlainDomain(domain string) string { + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + tail := maskHostTail(parts) + numStars := len(parts) - len(tail) + if numStars < 1 { + numStars = 1 + } + stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".") + return stars + "." + strings.Join(tail, ".") +} + // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com @@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string { return urlStr } - // Split host by dots - parts := strings.Split(host, ".") - if len(parts) < 2 { - // If less than 2 parts, just mask the whole host - return u.Scheme + "://***" + u.Path - } - - // Keep the TLD (Top Level Domain) and mask the rest - var maskedHost string - if len(parts) == 2 { - // example.com -> ***.com - maskedHost = "***." + parts[len(parts)-1] - } else { - // Handle cases like sub.domain.co.uk or api.example.com - // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) - lastPart := parts[len(parts)-1] - secondLastPart := parts[len(parts)-2] - - if len(lastPart) == 2 && len(secondLastPart) <= 3 { - // Likely country code TLD like co.uk, com.cn - maskedHost = "***." + secondLastPart + "." + lastPart - } else { - // Regular TLD like .com, .org - maskedHost = "***." + lastPart - } - } + // Mask host with unified logic + maskedHost := maskHostForURL(host) result := u.Scheme + "://" + maskedHost @@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string { // Mask domain names without protocol (like openai.com, www.openai.com) domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { - // Skip if it's already been processed as part of a URL + // Skip if it's already part of a URL to avoid partial masking if strings.Contains(str, "://"+domain) { return domain } - - parts := strings.Split(domain, ".") - if len(parts) < 2 { - return domain - } - - // Handle different domain patterns - if len(parts) == 2 { - // openai.com -> ***.com - return "***." + parts[1] - } else { - // www.openai.com -> ***.***.com - // api.openai.com -> ***.***.com - lastPart := parts[len(parts)-1] - return "***.***." + lastPart - } + return maskHostForPlainDomain(domain) }) // Mask IP addresses diff --git a/controller/relay.go b/controller/relay.go index 57955a18..b0c995fb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { meta := request.GetTokenCountMeta() if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveText(meta.CombineText) - if err != nil { + contains, words := service.CheckSensitiveText(meta.CombineText) + if contains { logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) return @@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } defer func() { - if newAPIError != nil { + // Only return quota if downstream failed and quota was actually pre-consumed + if newAPIError != nil && preConsumedQuota != 0 { service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) } }() diff --git a/logger/logger.go b/logger/logger.go index ca81d624..d59e51cb 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" "io" "log" "one-api/common" @@ -13,6 +11,9 @@ import ( "path/filepath" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) const ( @@ -29,6 +30,9 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { + defer func() { + setupLogWorking = false + }() if *common.LogDir != "" { ok := setupLogLock.TryLock() if !ok { @@ -37,7 +41,6 @@ func SetupLogger() { } defer func() { setupLogLock.Unlock() - setupLogWorking = false }() logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 54ed476f..9d5c190f 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap if err != nil || resp.StatusCode != 101 { return nil, nil, err } + + defer func() { + conn.Close() + }() + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap } dataChan <- response if response.Payload.Choices.Status == 2 { - err := conn.Close() if err != nil { common.SysLog("error closing websocket connection: " + err.Error()) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index ddc424b4..8f846f1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ textRequest, ok := info.Request.(*dto.ClaudeRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request)) } err := helper.ModelMappedHelper(c, info, textRequest) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 1ebb0581..51142ff8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -87,26 +87,24 @@ type RelayInfo struct { UsePrice bool RelayMode int OriginModelName string - //RecodeModelName string - RequestURLPath string - PromptTokens int - //SupportStreamOptions bool - ShouldIncludeUsage bool - DisablePing bool // 是否禁止向下游发送自定义 Ping - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat types.RelayFormat - SendResponseCount int - FinalPreConsumedQuota int // 最终预消耗的配额 + RequestURLPath string + PromptTokens int + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 PriceData types.PriceData diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index f7906cf9..99f0d817 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request)) } err := helper.ModelMappedHelper(c, info, embeddingRequest) diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 3ebe0884..d50fff42 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ request, ok := info.Request.(*dto.GeminiChatRequest) if !ok { - common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) + common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request)) } // model mapped 模型映射 diff --git a/relay/helper/common.go b/relay/helper/common.go index 5075314d..4b2c51eb 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { } func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { + if ws == nil { + return + } errorObj := &dto.RealtimeEvent{ Type: "error", EventId: GetLocalRealtimeID(c), diff --git a/service/sensitive.go b/service/sensitive.go index b3e3c4d6..25cfd46f 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -2,7 +2,6 @@ package service import ( "errors" - "fmt" "one-api/dto" "one-api/setting" "strings" @@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { return nil, nil } -func CheckSensitiveText(text string) ([]string, error) { - if ok, words := SensitiveWordContains(text); ok { - return words, errors.New("sensitive words detected") - } - return nil, nil -} - -func CheckSensitiveInput(input any) ([]string, error) { - switch v := input.(type) { - case string: - return CheckSensitiveText(v) - case []string: - var builder strings.Builder - for _, s := range v { - builder.WriteString(s) - } - return CheckSensitiveText(builder.String()) - } - return CheckSensitiveText(fmt.Sprintf("%v", input)) +func CheckSensitiveText(text string) (bool, []string) { + return SensitiveWordContains(text) } // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 diff --git a/types/error.go b/types/error.go index 8585461a..07486c27 100644 --- a/types/error.go +++ b/types/error.go @@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string { return string(e.errorCode) } errStr := e.Err.Error() - if e.StatusCode == http.StatusServiceUnavailable { - if e.errorCode == ErrorCodeModelNotFound { - errStr = "上游分组模型服务不可用,请稍后再试" - } else { - if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") { - errStr = "上游分组模型服务不可用,请稍后再试" - } - } - } return common.MaskSensitiveInfo(errStr) }