refactor: improve request type validation and enhance sensitive information masking

This commit is contained in:
CaIon
2025-08-15 13:20:36 +08:00
parent 03fc89da00
commit 5fe1ce89ec
11 changed files with 87 additions and 102 deletions

View File

@@ -117,6 +117,48 @@ func MaskEmail(email string) string {
return "***@" + email[atIndex+1:] return "***@" + email[atIndex+1:]
} }
// maskHostTail returns the tail parts of a domain/host that should be preserved.
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
func maskHostTail(parts []string) []string {
if len(parts) < 2 {
return parts
}
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
return []string{secondLastPart, lastPart}
}
return []string{lastPart}
}
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
func maskHostForURL(host string) string {
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "***"
}
tail := maskHostTail(parts)
return "***." + strings.Join(tail, ".")
}
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
func maskHostForPlainDomain(domain string) string {
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
tail := maskHostTail(parts)
numStars := len(parts) - len(tail)
if numStars < 1 {
numStars = 1
}
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
return stars + "." + strings.Join(tail, ".")
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string // MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example: // Example:
// http://example.com -> http://***.com // http://example.com -> http://***.com
@@ -140,32 +182,8 @@ func MaskSensitiveInfo(str string) string {
return urlStr return urlStr
} }
// Split host by dots // Mask host with unified logic
parts := strings.Split(host, ".") maskedHost := maskHostForURL(host)
if len(parts) < 2 {
// If less than 2 parts, just mask the whole host
return u.Scheme + "://***" + u.Path
}
// Keep the TLD (Top Level Domain) and mask the rest
var maskedHost string
if len(parts) == 2 {
// example.com -> ***.com
maskedHost = "***." + parts[len(parts)-1]
} else {
// Handle cases like sub.domain.co.uk or api.example.com
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
maskedHost = "***." + secondLastPart + "." + lastPart
} else {
// Regular TLD like .com, .org
maskedHost = "***." + lastPart
}
}
result := u.Scheme + "://" + maskedHost result := u.Scheme + "://" + maskedHost
@@ -208,26 +226,11 @@ func MaskSensitiveInfo(str string) string {
// Mask domain names without protocol (like openai.com, www.openai.com) // Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
// Skip if it's already been processed as part of a URL // Skip if it's already part of a URL to avoid partial masking
if strings.Contains(str, "://"+domain) { if strings.Contains(str, "://"+domain) {
return domain return domain
} }
return maskHostForPlainDomain(domain)
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
// Handle different domain patterns
if len(parts) == 2 {
// openai.com -> ***.com
return "***." + parts[1]
} else {
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
lastPart := parts[len(parts)-1]
return "***.***." + lastPart
}
}) })
// Mask IP addresses // Mask IP addresses

View File

@@ -113,8 +113,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
meta := request.GetTokenCountMeta() meta := request.GetTokenCountMeta()
if setting.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveText(meta.CombineText) contains, words := service.CheckSensitiveText(meta.CombineText)
if err != nil { if contains {
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return return
@@ -139,7 +139,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
} }
defer func() { defer func() {
if newAPIError != nil { // Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
} }
}() }()

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io" "io"
"log" "log"
"one-api/common" "one-api/common"
@@ -13,6 +11,9 @@ import (
"path/filepath" "path/filepath"
"sync" "sync"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
) )
const ( const (
@@ -29,6 +30,9 @@ var setupLogLock sync.Mutex
var setupLogWorking bool var setupLogWorking bool
func SetupLogger() { func SetupLogger() {
defer func() {
setupLogWorking = false
}()
if *common.LogDir != "" { if *common.LogDir != "" {
ok := setupLogLock.TryLock() ok := setupLogLock.TryLock()
if !ok { if !ok {
@@ -37,7 +41,6 @@ func SetupLogger() {
} }
defer func() { defer func() {
setupLogLock.Unlock() setupLogLock.Unlock()
setupLogWorking = false
}() }()
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)

View File

@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
if err != nil || resp.StatusCode != 101 { if err != nil || resp.StatusCode != 101 {
return nil, nil, err return nil, nil, err
} }
defer func() {
conn.Close()
}()
data := requestOpenAI2Xunfei(textRequest, appId, domain) data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data) err = conn.WriteJSON(data)
if err != nil { if err != nil {
@@ -229,7 +234,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
} }
dataChan <- response dataChan <- response
if response.Payload.Choices.Status == 2 { if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil { if err != nil {
common.SysLog("error closing websocket connection: " + err.Error()) common.SysLog("error closing websocket connection: " + err.Error())
} }

View File

@@ -24,7 +24,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
textRequest, ok := info.Request.(*dto.ClaudeRequest) textRequest, ok := info.Request.(*dto.ClaudeRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request))
} }
err := helper.ModelMappedHelper(c, info, textRequest) err := helper.ModelMappedHelper(c, info, textRequest)

View File

@@ -87,26 +87,24 @@ type RelayInfo struct {
UsePrice bool UsePrice bool
RelayMode int RelayMode int
OriginModelName string OriginModelName string
//RecodeModelName string RequestURLPath string
RequestURLPath string PromptTokens int
PromptTokens int ShouldIncludeUsage bool
//SupportStreamOptions bool DisablePing bool // 是否禁止向下游发送自定义 Ping
ShouldIncludeUsage bool ClientWs *websocket.Conn
DisablePing bool // 是否禁止向下游发送自定义 Ping TargetWs *websocket.Conn
ClientWs *websocket.Conn InputAudioFormat string
TargetWs *websocket.Conn OutputAudioFormat string
InputAudioFormat string RealtimeTools []dto.RealTimeTool
OutputAudioFormat string IsFirstRequest bool
RealtimeTools []dto.RealTimeTool AudioUsage bool
IsFirstRequest bool ReasoningEffort string
AudioUsage bool UserSetting dto.UserSetting
ReasoningEffort string UserEmail string
UserSetting dto.UserSetting UserQuota int
UserEmail string RelayFormat types.RelayFormat
UserQuota int SendResponseCount int
RelayFormat types.RelayFormat FinalPreConsumedQuota int // 最终预消耗的配额
SendResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
PriceData types.PriceData PriceData types.PriceData

View File

@@ -21,7 +21,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest) embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request))
} }
err := helper.ModelMappedHelper(c, info, embeddingRequest) err := helper.ModelMappedHelper(c, info, embeddingRequest)

View File

@@ -55,7 +55,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
request, ok := info.Request.(*dto.GeminiChatRequest) request, ok := info.Request.(*dto.GeminiChatRequest)
if !ok { if !ok {
common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request)) common.FatalLog(fmt.Sprintf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request))
} }
// model mapped 模型映射 // model mapped 模型映射

View File

@@ -122,6 +122,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
} }
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
if ws == nil {
return
}
errorObj := &dto.RealtimeEvent{ errorObj := &dto.RealtimeEvent{
Type: "error", Type: "error",
EventId: GetLocalRealtimeID(c), EventId: GetLocalRealtimeID(c),

View File

@@ -2,7 +2,6 @@ package service
import ( import (
"errors" "errors"
"fmt"
"one-api/dto" "one-api/dto"
"one-api/setting" "one-api/setting"
"strings" "strings"
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
return nil, nil return nil, nil
} }
func CheckSensitiveText(text string) ([]string, error) { func CheckSensitiveText(text string) (bool, []string) {
if ok, words := SensitiveWordContains(text); ok { return SensitiveWordContains(text)
return words, errors.New("sensitive words detected")
}
return nil, nil
}
func CheckSensitiveInput(input any) ([]string, error) {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
var builder strings.Builder
for _, s := range v {
builder.WriteString(s)
}
return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
} }
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表

View File

@@ -121,15 +121,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
return string(e.errorCode) return string(e.errorCode)
} }
errStr := e.Err.Error() errStr := e.Err.Error()
if e.StatusCode == http.StatusServiceUnavailable {
if e.errorCode == ErrorCodeModelNotFound {
errStr = "上游分组模型服务不可用,请稍后再试"
} else {
if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
errStr = "上游分组模型服务不可用,请稍后再试"
}
}
}
return common.MaskSensitiveInfo(errStr) return common.MaskSensitiveInfo(errStr)
} }