refactor: improve request type validation and enhance sensitive information masking
This commit is contained in:
@@ -117,6 +117,48 @@ func MaskEmail(email string) string {
|
|||||||
return "***@" + email[atIndex+1:]
|
return "***@" + email[atIndex+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maskHostTail returns the tail parts of a domain/host that should be preserved.
|
||||||
|
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
|
||||||
|
func maskHostTail(parts []string) []string {
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
lastPart := parts[len(parts)-1]
|
||||||
|
secondLastPart := parts[len(parts)-2]
|
||||||
|
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||||
|
// Likely country code TLD like co.uk, com.cn
|
||||||
|
return []string{secondLastPart, lastPart}
|
||||||
|
}
|
||||||
|
return []string{lastPart}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
|
||||||
|
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
|
||||||
|
func maskHostForURL(host string) string {
|
||||||
|
parts := strings.Split(host, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "***"
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
return "***." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
|
||||||
|
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
|
||||||
|
func maskHostForPlainDomain(domain string) string {
|
||||||
|
parts := strings.Split(domain, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
numStars := len(parts) - len(tail)
|
||||||
|
if numStars < 1 {
|
||||||
|
numStars = 1
|
||||||
|
}
|
||||||
|
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
|
||||||
|
return stars + "." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
||||||
// Example:
|
// Example:
|
||||||
// http://example.com -> http://***.com
|
// http://example.com -> http://***.com
|
||||||
@@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string {
|
|||||||
return urlStr
|
return urlStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split host by dots
|
// Mask host with unified logic
|
||||||
parts := strings.Split(host, ".")
|
maskedHost := maskHostForURL(host)
|
||||||
if len(parts) < 2 {
|
|
||||||
// If less than 2 parts, just mask the whole host
|
|
||||||
return u.Scheme + "://***" + u.Path
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keep the TLD (Top Level Domain) and mask the rest
|
|
||||||
var maskedHost string
|
|
||||||
if len(parts) == 2 {
|
|
||||||
// example.com -> ***.com
|
|
||||||
maskedHost = "***." + parts[len(parts)-1]
|
|
||||||
} else {
|
|
||||||
// Handle cases like sub.domain.co.uk or api.example.com
|
|
||||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
|
||||||
lastPart := parts[len(parts)-1]
|
|
||||||
secondLastPart := parts[len(parts)-2]
|
|
||||||
|
|
||||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
|
||||||
// Likely country code TLD like co.uk, com.cn
|
|
||||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
|
||||||
} else {
|
|
||||||
// Regular TLD like .com, .org
|
|
||||||
maskedHost = "***." + lastPart
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := u.Scheme + "://" + maskedHost
|
result := u.Scheme + "://" + maskedHost
|
||||||
|
|
||||||
@@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string {
|
|||||||
// Mask domain names without protocol (like openai.com, www.openai.com)
|
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||||
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||||
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||||
// Skip if it's already been processed as part of a URL
|
// Skip if it's already part of a URL to avoid partial masking
|
||||||
if strings.Contains(str, "://"+domain) {
|
if strings.Contains(str, "://"+domain) {
|
||||||
return domain
|
return domain
|
||||||
}
|
}
|
||||||
|
return maskHostForPlainDomain(domain)
|
||||||
parts := strings.Split(domain, ".")
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return domain
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle different domain patterns
|
|
||||||
if len(parts) == 2 {
|
|
||||||
// openai.com -> ***.com
|
|
||||||
return "***." + parts[1]
|
|
||||||
} else {
|
|
||||||
// www.openai.com -> ***.***.com
|
|
||||||
// api.openai.com -> ***.***.com
|
|
||||||
lastPart := parts[len(parts)-1]
|
|
||||||
return "***.***." + lastPart
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Mask IP addresses
|
// Mask IP addresses
|
||||||
|
|||||||
@@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
meta := request.GetTokenCountMeta()
|
meta := request.GetTokenCountMeta()
|
||||||
|
|
||||||
if setting.ShouldCheckPromptSensitive() {
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
words, err := service.CheckSensitiveText(meta.CombineText)
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||||
if err != nil {
|
if contains {
|
||||||
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
return
|
return
|
||||||
@@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if newAPIError != nil {
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||||
|
if newAPIError != nil && preConsumedQuota != 0 {
|
||||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -13,6 +11,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -29,6 +30,9 @@ var setupLogLock sync.Mutex
|
|||||||
var setupLogWorking bool
|
var setupLogWorking bool
|
||||||
|
|
||||||
func SetupLogger() {
|
func SetupLogger() {
|
||||||
|
defer func() {
|
||||||
|
setupLogWorking = false
|
||||||
|
}()
|
||||||
if *common.LogDir != "" {
|
if *common.LogDir != "" {
|
||||||
ok := setupLogLock.TryLock()
|
ok := setupLogLock.TryLock()
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -37,7 +41,6 @@ func SetupLogger() {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
setupLogLock.Unlock()
|
setupLogLock.Unlock()
|
||||||
setupLogWorking = false
|
|
||||||
}()
|
}()
|
||||||
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
||||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
|||||||
@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
|||||||
if err != nil || resp.StatusCode != 101 {
|
if err != nil || resp.StatusCode != 101 {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||||
err = conn.WriteJSON(data)
|
err = conn.WriteJSON(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
|||||||
}
|
}
|
||||||
dataChan <- response
|
dataChan <- response
|
||||||
if response.Payload.Choices.Status == 2 {
|
if response.Payload.Choices.Status == 2 {
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("error closing websocket connection: " + err.Error())
|
common.SysLog("error closing websocket connection: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
textRequest, ok := info.Request.(*dto.ClaudeRequest)
|
textRequest, ok := info.Request.(*dto.ClaudeRequest)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
|
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := helper.ModelMappedHelper(c, info, textRequest)
|
err := helper.ModelMappedHelper(c, info, textRequest)
|
||||||
|
|||||||
@@ -87,10 +87,8 @@ type RelayInfo struct {
|
|||||||
UsePrice bool
|
UsePrice bool
|
||||||
RelayMode int
|
RelayMode int
|
||||||
OriginModelName string
|
OriginModelName string
|
||||||
//RecodeModelName string
|
|
||||||
RequestURLPath string
|
RequestURLPath string
|
||||||
PromptTokens int
|
PromptTokens int
|
||||||
//SupportStreamOptions bool
|
|
||||||
ShouldIncludeUsage bool
|
ShouldIncludeUsage bool
|
||||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||||
ClientWs *websocket.Conn
|
ClientWs *websocket.Conn
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
|
|
||||||
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
|
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
|
||||||
if !ok {
|
if !ok {
|
||||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
|
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := helper.ModelMappedHelper(c, info, embeddingRequest)
|
err := helper.ModelMappedHelper(c, info, embeddingRequest)
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
|
|
||||||
request, ok := info.Request.(*dto.GeminiChatRequest)
|
request, ok := info.Request.(*dto.GeminiChatRequest)
|
||||||
if !ok {
|
if !ok {
|
||||||
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
|
common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
|
||||||
}
|
}
|
||||||
|
|
||||||
// model mapped 模型映射
|
// model mapped 模型映射
|
||||||
|
|||||||
@@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
|
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
|
||||||
|
if ws == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
errorObj := &dto.RealtimeEvent{
|
errorObj := &dto.RealtimeEvent{
|
||||||
Type: "error",
|
Type: "error",
|
||||||
EventId: GetLocalRealtimeID(c),
|
EventId: GetLocalRealtimeID(c),
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckSensitiveText(text string) ([]string, error) {
|
func CheckSensitiveText(text string) (bool, []string) {
|
||||||
if ok, words := SensitiveWordContains(text); ok {
|
return SensitiveWordContains(text)
|
||||||
return words, errors.New("sensitive words detected")
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckSensitiveInput(input any) ([]string, error) {
|
|
||||||
switch v := input.(type) {
|
|
||||||
case string:
|
|
||||||
return CheckSensitiveText(v)
|
|
||||||
case []string:
|
|
||||||
var builder strings.Builder
|
|
||||||
for _, s := range v {
|
|
||||||
builder.WriteString(s)
|
|
||||||
}
|
|
||||||
return CheckSensitiveText(builder.String())
|
|
||||||
}
|
|
||||||
return CheckSensitiveText(fmt.Sprintf("%v", input))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
||||||
|
|||||||
@@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
|
|||||||
return string(e.errorCode)
|
return string(e.errorCode)
|
||||||
}
|
}
|
||||||
errStr := e.Err.Error()
|
errStr := e.Err.Error()
|
||||||
if e.StatusCode == http.StatusServiceUnavailable {
|
|
||||||
if e.errorCode == ErrorCodeModelNotFound {
|
|
||||||
errStr = "上游分组模型服务不可用,请稍后再试"
|
|
||||||
} else {
|
|
||||||
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
|
|
||||||
errStr = "上游分组模型服务不可用,请稍后再试"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return common.MaskSensitiveInfo(errStr)
|
return common.MaskSensitiveInfo(errStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user