refactor: enhance error handling and masking for model not found scenarios

This commit is contained in:
CaIon
2025-08-15 12:41:05 +08:00
parent 7f1f368065
commit 44e9b02b3f
4 changed files with 49 additions and 5 deletions

View File

@@ -99,12 +99,15 @@ func GetJsonString(data any) string {
return string(b) return string(b)
} }
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example: // Example:
// http://example.com -> http://***.com // http://example.com -> http://***.com
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
// 192.168.1.1 -> ***.***.***.*** // 192.168.1.1 -> ***.***.***.***
// openai.com -> ***.com
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
func MaskSensitiveInfo(str string) string { func MaskSensitiveInfo(str string) string {
// Mask URLs // Mask URLs
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
@@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string {
return result return result
}) })
// Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
// Skip if it's already been processed as part of a URL
if strings.Contains(str, "://"+domain) {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
// Handle different domain patterns
if len(parts) == 2 {
// openai.com -> ***.com
return "***." + parts[1]
} else {
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
lastPart := parts[len(parts)-1]
return "***.***." + lastPart
}
})
// Mask IP addresses // Mask IP addresses
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
str = ipPattern.ReplaceAllString(str, "***.***.***.***") str = ipPattern.ReplaceAllString(str, "***.***.***.***")

View File

@@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员" // message = "数据库一致性已被破坏,请联系管理员"
//} //}
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
return return
} }
if channel == nil { if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model)) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
return return
} }
} }

View File

@@ -7,12 +7,17 @@ import (
"one-api/logger" "one-api/logger"
) )
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
codeStr := ""
if len(code) > 0 {
codeStr = code[0]
}
userId := c.GetInt("id") userId := c.GetInt("id")
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "new_api_error", "type": "new_api_error",
"code": codeStr,
}, },
}) })
c.Abort() c.Abort()

View File

@@ -67,6 +67,7 @@ const (
ErrorCodeBadResponseBody ErrorCode = "bad_response_body" ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
ErrorCodeEmptyResponse ErrorCode = "empty_response" ErrorCodeEmptyResponse ErrorCode = "empty_response"
ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error" ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
ErrorCodeModelNotFound ErrorCode = "model_not_found"
// sql error // sql error
ErrorCodeQueryDataError ErrorCode = "query_data_error" ErrorCodeQueryDataError ErrorCode = "query_data_error"
@@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string {
if e.Err == nil { if e.Err == nil {
return string(e.errorCode) return string(e.errorCode)
} }
return common.MaskSensitiveInfo(e.Err.Error()) errStr := e.Err.Error()
if e.StatusCode == http.StatusServiceUnavailable {
if e.errorCode == ErrorCodeModelNotFound {
errStr = "上游分组模型服务不可用,请稍后再试"
} else {
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
errStr = "上游分组模型服务不可用,请稍后再试"
}
}
}
return common.MaskSensitiveInfo(errStr)
} }
func (e *NewAPIError) SetMessage(message string) { func (e *NewAPIError) SetMessage(message string) {