diff --git a/common/str.go b/common/str.go index f5399eab..7d4cdaf0 100644 --- a/common/str.go +++ b/common/str.go @@ -99,12 +99,15 @@ func GetJsonString(data any) string { return string(b) } -// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // Example: // http://example.com -> http://***.com // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** // 192.168.1.1 -> ***.***.***.*** +// openai.com -> ***.com +// www.openai.com -> ***.***.com +// api.openai.com -> ***.***.com func MaskSensitiveInfo(str string) string { // Mask URLs urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) @@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string { return result }) + // Mask domain names without protocol (like openai.com, www.openai.com) + domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) + str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { + // Skip if it's already been processed as part of a URL + if strings.Contains(str, "://"+domain) { + return domain + } + + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + + // Handle different domain patterns + if len(parts) == 2 { + // openai.com -> ***.com + return "***." + parts[1] + } else { + // www.openai.com -> ***.***.com + // api.openai.com -> ***.***.com + lastPart := parts[len(parts)-1] + return "***.***." + lastPart + } + }) + // Mask IP addresses ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) str = ipPattern.ReplaceAllString(str, "***.***.***.***") diff --git a/middleware/distributor.go b/middleware/distributor.go index 286a4d1f..28b66a3a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) { // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) // message = "数据库一致性已被破坏,请联系管理员" //} - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound)) return } if channel == nil { - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model)) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound)) return } } diff --git a/middleware/utils.go b/middleware/utils.go index e23bbff7..77d1eb80 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -7,12 +7,17 @@ import ( "one-api/logger" ) -func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) { + codeStr := "" + if len(code) > 0 { + codeStr = code[0] + } userId := c.GetInt("id") c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "type": "new_api_error", + "code": codeStr, }, }) c.Abort() diff --git a/types/error.go b/types/error.go index 2cfeb541..8585461a 100644 --- a/types/error.go +++ b/types/error.go @@ -67,6 +67,7 @@ const ( ErrorCodeBadResponseBody ErrorCode = "bad_response_body" ErrorCodeEmptyResponse ErrorCode = "empty_response" ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error" + ErrorCodeModelNotFound ErrorCode = "model_not_found" // sql error ErrorCodeQueryDataError ErrorCode = "query_data_error" @@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string { if e.Err == nil { return string(e.errorCode) } - return common.MaskSensitiveInfo(e.Err.Error()) + errStr := e.Err.Error() + if e.StatusCode == http.StatusServiceUnavailable { + if e.errorCode == ErrorCodeModelNotFound { + errStr = "上游分组模型服务不可用,请稍后再试" + } else { + if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") { + errStr = "上游分组模型服务不可用,请稍后再试" + } + } + } + return common.MaskSensitiveInfo(errStr) } func (e *NewAPIError) SetMessage(message string) {