refactor: improve request type validation and enhance sensitive information masking

This commit is contained in:
CaIon
2025-08-15 13:20:36 +08:00
parent 03fc89da00
commit 5fe1ce89ec
11 changed files with 87 additions and 102 deletions

View File

@@ -117,6 +117,48 @@ func MaskEmail(email string) string {
return "***@" + email[atIndex+1:] return "***@" + email[atIndex+1:]
} }
// maskHostTail returns the tail parts of a domain/host that should be preserved.
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
func maskHostTail(parts []string) []string {
if len(parts) < 2 {
return parts
}
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
return []string{secondLastPart, lastPart}
}
return []string{lastPart}
}
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
func maskHostForURL(host string) string {
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "***"
}
tail := maskHostTail(parts)
return "***." + strings.Join(tail, ".")
}
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
func maskHostForPlainDomain(domain string) string {
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
tail := maskHostTail(parts)
numStars := len(parts) - len(tail)
if numStars < 1 {
numStars = 1
}
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
return stars + "." + strings.Join(tail, ".")
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example: // Example:
// http://example.com -> http://***.com // http://example.com -> http://***.com
@@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string {
return urlStr return urlStr
} }
// Split host by dots // Mask host with unified logic
parts := strings.Split(host, ".") maskedHost := maskHostForURL(host)
if len(parts) < 2 {
// If less than 2 parts, just mask the whole host
return u.Scheme + "://***" + u.Path
}
// Keep the TLD (Top Level Domain) and mask the rest
var maskedHost string
if len(parts) == 2 {
// example.com -> ***.com
maskedHost = "***." + parts[len(parts)-1]
} else {
// Handle cases like sub.domain.co.uk or api.example.com
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
maskedHost = "***." + secondLastPart + "." + lastPart
} else {
// Regular TLD like .com, .org
maskedHost = "***." + lastPart
}
}
result := u.Scheme + "://" + maskedHost result := u.Scheme + "://" + maskedHost
@@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string {
// Mask domain names without protocol (like openai.com, www.openai.com) // Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
// Skip if it's already been processed as part of a URL // Skip if it's already part of a URL to avoid partial masking
if strings.Contains(str, "://"+domain) { if strings.Contains(str, "://"+domain) {
return domain return domain
} }
return maskHostForPlainDomain(domain)
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
// Handle different domain patterns
if len(parts) == 2 {
// openai.com -> ***.com
return "***." + parts[1]
} else {
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
lastPart := parts[len(parts)-1]
return "***.***." + lastPart
}
}) })
// Mask IP addresses // Mask IP addresses

View File

@@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
meta := request.GetTokenCountMeta() meta := request.GetTokenCountMeta()
if setting.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveText(meta.CombineText) contains, words := service.CheckSensitiveText(meta.CombineText)
if err != nil { if contains {
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return return
@@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
} }
defer func() { defer func() {
if newAPIError != nil { // Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
} }
}() }()

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io" "io"
"log" "log"
"one-api/common" "one-api/common"
@@ -13,6 +11,9 @@ import (
"path/filepath" "path/filepath"
"sync" "sync"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
) )
const ( const (
@@ -29,6 +30,9 @@ var setupLogLock sync.Mutex
var setupLogWorking bool var setupLogWorking bool
func SetupLogger() { func SetupLogger() {
defer func() {
setupLogWorking = false
}()
if *common.LogDir != "" { if *common.LogDir != "" {
ok := setupLogLock.TryLock() ok := setupLogLock.TryLock()
if !ok { if !ok {
@@ -37,7 +41,6 @@ func SetupLogger() {
} }
defer func() { defer func() {
setupLogLock.Unlock() setupLogLock.Unlock()
setupLogWorking = false
}() }()
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)

View File

@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
if err != nil || resp.StatusCode != 101 { if err != nil || resp.StatusCode != 101 {
return nil, nil, err return nil, nil, err
} }
defer func() {
conn.Close()
}()
data := requestOpenAI2Xunfei(textRequest, appId, domain) data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data) err = conn.WriteJSON(data)
if err != nil { if err != nil {
@@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
} }
dataChan <- response dataChan <- response
if response.Payload.Choices.Status == 2 { if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil { if err != nil {
common.SysLog("error closing websocket connection: " + err.Error()) common.SysLog("error closing websocket connection: " + err.Error())
} }

View File

@@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
textRequest, ok := info.Request.(*dto.ClaudeRequest) textRequest, ok := info.Request.(*dto.ClaudeRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
} }
err := helper.ModelMappedHelper(c, info, textRequest) err := helper.ModelMappedHelper(c, info, textRequest)

View File

@@ -87,10 +87,8 @@ type RelayInfo struct {
UsePrice bool UsePrice bool
RelayMode int RelayMode int
OriginModelName string OriginModelName string
//RecodeModelName string
RequestURLPath string RequestURLPath string
PromptTokens int PromptTokens int
//SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping DisablePing bool // 是否禁止向下游发送自定义 Ping
ClientWs *websocket.Conn ClientWs *websocket.Conn

View File

@@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
} }
err := helper.ModelMappedHelper(c, info, embeddingRequest) err := helper.ModelMappedHelper(c, info, embeddingRequest)

View File

@@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
request, ok := info.Request.(*dto.GeminiChatRequest) request, ok := info.Request.(*dto.GeminiChatRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
} }
// model mapped 模型映射 // model mapped 模型映射

View File

@@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
} }
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
if ws == nil {
return
}
errorObj := &dto.RealtimeEvent{ errorObj := &dto.RealtimeEvent{
Type: "error", Type: "error",
EventId: GetLocalRealtimeID(c), EventId: GetLocalRealtimeID(c),

View File

@@ -2,7 +2,6 @@ package service
import ( import (
"errors" "errors"
"fmt"
"one-api/dto" "one-api/dto"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
return nil, nil return nil, nil
} }
func CheckSensitiveText(text string) ([]string, error) { func CheckSensitiveText(text string) (bool, []string) {
if ok, words := SensitiveWordContains(text); ok { return SensitiveWordContains(text)
return words, errors.New("sensitive words detected")
}
return nil, nil
}
func CheckSensitiveInput(input any) ([]string, error) {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
var builder strings.Builder
for _, s := range v {
builder.WriteString(s)
}
return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
} }
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表

View File

@@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
return string(e.errorCode) return string(e.errorCode)
} }
errStr := e.Err.Error() errStr := e.Err.Error()
if e.StatusCode == http.StatusServiceUnavailable {
if e.errorCode == ErrorCodeModelNotFound {
errStr = "上游分组模型服务不可用,请稍后再试"
} else {
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
errStr = "上游分组模型服务不可用,请稍后再试"
}
}
}
return common.MaskSensitiveInfo(errStr) return common.MaskSensitiveInfo(errStr)
} }