refactor: enhance error handling and masking for model not found scenarios
This commit is contained in:
@@ -99,12 +99,15 @@ func GetJsonString(data any) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
||||||
// Example:
|
// Example:
|
||||||
// http://example.com -> http://***.com
|
// http://example.com -> http://***.com
|
||||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||||
// 192.168.1.1 -> ***.***.***.***
|
// 192.168.1.1 -> ***.***.***.***
|
||||||
|
// openai.com -> ***.com
|
||||||
|
// www.openai.com -> ***.***.com
|
||||||
|
// api.openai.com -> ***.***.com
|
||||||
func MaskSensitiveInfo(str string) string {
|
func MaskSensitiveInfo(str string) string {
|
||||||
// Mask URLs
|
// Mask URLs
|
||||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||||
@@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string {
|
|||||||
return result
|
return result
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||||
|
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||||
|
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||||
|
// Skip if it's already been processed as part of a URL
|
||||||
|
if strings.Contains(str, "://"+domain) {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(domain, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different domain patterns
|
||||||
|
if len(parts) == 2 {
|
||||||
|
// openai.com -> ***.com
|
||||||
|
return "***." + parts[1]
|
||||||
|
} else {
|
||||||
|
// www.openai.com -> ***.***.com
|
||||||
|
// api.openai.com -> ***.***.com
|
||||||
|
lastPart := parts[len(parts)-1]
|
||||||
|
return "***.***." + lastPart
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Mask IP addresses
|
// Mask IP addresses
|
||||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||||
|
|||||||
@@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) {
|
|||||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
// message = "数据库一致性已被破坏,请联系管理员"
|
// message = "数据库一致性已被破坏,请联系管理员"
|
||||||
//}
|
//}
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,17 @@ import (
|
|||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
|
||||||
|
codeStr := ""
|
||||||
|
if len(code) > 0 {
|
||||||
|
codeStr = code[0]
|
||||||
|
}
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
"type": "new_api_error",
|
"type": "new_api_error",
|
||||||
|
"code": codeStr,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ const (
|
|||||||
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
|
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
|
||||||
ErrorCodeEmptyResponse ErrorCode = "empty_response"
|
ErrorCodeEmptyResponse ErrorCode = "empty_response"
|
||||||
ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
|
ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
|
||||||
|
ErrorCodeModelNotFound ErrorCode = "model_not_found"
|
||||||
|
|
||||||
// sql error
|
// sql error
|
||||||
ErrorCodeQueryDataError ErrorCode = "query_data_error"
|
ErrorCodeQueryDataError ErrorCode = "query_data_error"
|
||||||
@@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string {
|
|||||||
if e.Err == nil {
|
if e.Err == nil {
|
||||||
return string(e.errorCode)
|
return string(e.errorCode)
|
||||||
}
|
}
|
||||||
return common.MaskSensitiveInfo(e.Err.Error())
|
errStr := e.Err.Error()
|
||||||
|
if e.StatusCode == http.StatusServiceUnavailable {
|
||||||
|
if e.errorCode == ErrorCodeModelNotFound {
|
||||||
|
errStr = "上游分组模型服务不可用,请稍后再试"
|
||||||
|
} else {
|
||||||
|
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
|
||||||
|
errStr = "上游分组模型服务不可用,请稍后再试"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return common.MaskSensitiveInfo(errStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *NewAPIError) SetMessage(message string) {
|
func (e *NewAPIError) SetMessage(message string) {
|
||||||
|
|||||||
Reference in New Issue
Block a user