Files
sub2api/backend/internal/service/openai_client_transport.go
2026-02-28 15:01:20 +08:00

72 lines
2.0 KiB
Go

package service
import (
"strings"
"github.com/gin-gonic/gin"
)
// OpenAIClientTransport identifies the protocol a client used for the inbound
// request.
type OpenAIClientTransport string
// Canonical transport values. normalizeOpenAIClientTransport folds every
// accepted spelling onto one of these.
const (
// OpenAIClientTransportUnknown is the empty value; it is what the helpers in
// this file return when the transport is missing or not recognized.
OpenAIClientTransportUnknown OpenAIClientTransport = ""
// OpenAIClientTransportHTTP is the canonical value for HTTP clients; the
// aliases "http_sse" and "sse" normalize to it.
OpenAIClientTransportHTTP OpenAIClientTransport = "http"
// OpenAIClientTransportWS is the canonical value for WebSocket clients; the
// alias "websocket" normalizes to it.
OpenAIClientTransportWS OpenAIClientTransport = "ws"
)
// openAIClientTransportContextKey is the gin.Context key under which
// SetOpenAIClientTransport stores the transport (as a plain string).
const openAIClientTransportContextKey = "openai_client_transport"
// SetOpenAIClientTransport records the client's inbound protocol on the
// request context.
//
// The value is normalized first and stored as a plain string under
// openAIClientTransportContextKey. A nil context, or a transport that
// normalizes to OpenAIClientTransportUnknown, is a no-op, so any value already
// stored on the context is left untouched.
func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
	if c == nil {
		return
	}
	if canonical := normalizeOpenAIClientTransport(transport); canonical != OpenAIClientTransportUnknown {
		c.Set(openAIClientTransportContextKey, string(canonical))
	}
}
// GetOpenAIClientTransport reads the client's inbound protocol from the
// request context.
//
// It accepts a stored value of type OpenAIClientTransport or string and
// returns its normalized form. It returns OpenAIClientTransportUnknown when
// the context is nil, the key is absent, or the stored value is nil or of any
// other type.
func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
	if c == nil {
		return OpenAIClientTransportUnknown
	}
	stored, found := c.Get(openAIClientTransportContextKey)
	if !found {
		return OpenAIClientTransportUnknown
	}
	// A nil stored value satisfies neither assertion below and therefore
	// falls through to the Unknown result.
	if typed, isTyped := stored.(OpenAIClientTransport); isTyped {
		return normalizeOpenAIClientTransport(typed)
	}
	if text, isText := stored.(string); isText {
		return normalizeOpenAIClientTransport(OpenAIClientTransport(text))
	}
	return OpenAIClientTransportUnknown
}
// normalizeOpenAIClientTransport maps a transport value onto its canonical
// constant.
//
// Matching ignores surrounding whitespace and letter case. "http", "http_sse"
// and "sse" yield OpenAIClientTransportHTTP; "ws" and "websocket" yield
// OpenAIClientTransportWS; anything else, including the empty string, yields
// OpenAIClientTransportUnknown.
func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
	cleaned := strings.ToLower(strings.TrimSpace(string(transport)))
	switch OpenAIClientTransport(cleaned) {
	case OpenAIClientTransportHTTP, "http_sse", "sse":
		return OpenAIClientTransportHTTP
	case OpenAIClientTransportWS, "websocket":
		return OpenAIClientTransportWS
	}
	return OpenAIClientTransportUnknown
}
// resolveOpenAIWSDecisionByClientTransport adjusts a protocol decision
// according to the client's inbound transport.
//
// When clientTransport equals OpenAIClientTransportHTTP, the incoming decision
// is discarded and the result of openAIWSHTTPDecision("client_protocol_http")
// is returned instead. For every other value the decision is returned
// unchanged. The comparison is exact, so aliases such as "sse" are not
// recognized here.
// NOTE(review): callers presumably pass an already-normalized value — confirm.
func resolveOpenAIWSDecisionByClientTransport(
	decision OpenAIWSProtocolDecision,
	clientTransport OpenAIClientTransport,
) OpenAIWSProtocolDecision {
	if clientTransport != OpenAIClientTransportHTTP {
		return decision
	}
	return openAIWSHTTPDecision("client_protocol_http")
}