refactor: enhance error handling and masking for model not found scenarios

This commit is contained in:
CaIon
2025-08-15 12:41:05 +08:00
parent 7f1f368065
commit 44e9b02b3f
4 changed files with 49 additions and 5 deletions

View File

@@ -99,12 +99,15 @@ func GetJsonString(data any) string {
return string(b)
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example:
// http://example.com -> http://***.com
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
// 192.168.1.1 -> ***.***.***.***
// openai.com -> ***.com
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
func MaskSensitiveInfo(str string) string {
// Mask URLs
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
@@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string {
return result
})
// Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
// Skip if it's already been processed as part of a URL
if strings.Contains(str, "://"+domain) {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
// Handle different domain patterns
if len(parts) == 2 {
// openai.com -> ***.com
return "***." + parts[1]
} else {
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
lastPart := parts[len(parts)-1]
return "***.***." + lastPart
}
})
// Mask IP addresses
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
str = ipPattern.ReplaceAllString(str, "***.***.***.***")

View File

@@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员"
//}
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model))
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
return
}
}

View File

@@ -7,12 +7,17 @@ import (
"one-api/logger"
)
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
codeStr := ""
if len(code) > 0 {
codeStr = code[0]
}
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "new_api_error",
"code": codeStr,
},
})
c.Abort()

View File

@@ -67,6 +67,7 @@ const (
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
ErrorCodeEmptyResponse ErrorCode = "empty_response"
ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
ErrorCodeModelNotFound ErrorCode = "model_not_found"
// sql error
ErrorCodeQueryDataError ErrorCode = "query_data_error"
@@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string {
if e.Err == nil {
return string(e.errorCode)
}
return common.MaskSensitiveInfo(e.Err.Error())
errStr := e.Err.Error()
if e.StatusCode == http.StatusServiceUnavailable {
if e.errorCode == ErrorCodeModelNotFound {
errStr = "上游分组模型服务不可用,请稍后再试"
} else {
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
errStr = "上游分组模型服务不可用,请稍后再试"
}
}
}
return common.MaskSensitiveInfo(errStr)
}
func (e *NewAPIError) SetMessage(message string) {