refactor: clean up doubao tts code
This commit is contained in:
@@ -24,7 +24,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Context keys for passing data between methods
|
|
||||||
contextKeyTTSRequest = "volcengine_tts_request"
|
contextKeyTTSRequest = "volcengine_tts_request"
|
||||||
contextKeyResponseFormat = "response_format"
|
contextKeyResponseFormat = "response_format"
|
||||||
)
|
)
|
||||||
@@ -76,27 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
Request: VolcengineTTSReqInfo{
|
Request: VolcengineTTSReqInfo{
|
||||||
ReqID: generateRequestID(),
|
ReqID: generateRequestID(),
|
||||||
Text: request.Input,
|
Text: request.Input,
|
||||||
Operation: "submit", // default WebSocket uses "submit"
|
Operation: "submit",
|
||||||
Model: info.OriginModelName,
|
Model: info.OriginModelName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// 同步扩展字段的厂商自定义metadata
|
|
||||||
if len(request.Metadata) > 0 {
|
if len(request.Metadata) > 0 {
|
||||||
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
|
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
|
||||||
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
|
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the request in context for WebSocket handler
|
|
||||||
c.Set(contextKeyTTSRequest, volcRequest)
|
c.Set(contextKeyTTSRequest, volcRequest)
|
||||||
// https://www.volcengine.com/docs/6561/1257584
|
|
||||||
// operation需要设置为submit才是流式返回
|
|
||||||
if volcRequest.Request.Operation == "submit" {
|
if volcRequest.Request.Operation == "submit" {
|
||||||
info.IsStream = true
|
info.IsStream = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return nil as WebSocket doesn't use traditional request body
|
|
||||||
jsonData, err := json.Marshal(volcRequest)
|
jsonData, err := json.Marshal(volcRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
||||||
@@ -115,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
writer := multipart.NewWriter(&requestBody)
|
writer := multipart.NewWriter(&requestBody)
|
||||||
|
|
||||||
writer.WriteField("model", request.Model)
|
writer.WriteField("model", request.Model)
|
||||||
// 获取所有表单字段
|
|
||||||
formData := c.Request.PostForm
|
formData := c.Request.PostForm
|
||||||
// 遍历表单字段并打印输出
|
|
||||||
for key, values := range formData {
|
for key, values := range formData {
|
||||||
if key == "model" {
|
if key == "model" {
|
||||||
continue
|
continue
|
||||||
@@ -127,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the multipart form to handle both single image and multiple images
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
|
||||||
return nil, errors.New("failed to parse multipart form")
|
return nil, errors.New("failed to parse multipart form")
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||||
// Check if "image" field exists in any form, including array notation
|
|
||||||
var imageFiles []*multipart.FileHeader
|
var imageFiles []*multipart.FileHeader
|
||||||
var exists bool
|
var exists bool
|
||||||
|
|
||||||
// First check for standard "image" field
|
|
||||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||||
// If not found, check for "image[]" field
|
|
||||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||||
// If still not found, iterate through all fields to find any that start with "image["
|
|
||||||
foundArrayImages := false
|
foundArrayImages := false
|
||||||
for fieldName, files := range c.Request.MultipartForm.File {
|
for fieldName, files := range c.Request.MultipartForm.File {
|
||||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||||
@@ -152,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no image fields found at all
|
|
||||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||||
return nil, errors.New("image is required")
|
return nil, errors.New("image is required")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process all image files
|
|
||||||
for i, fileHeader := range imageFiles {
|
for i, fileHeader := range imageFiles {
|
||||||
file, err := fileHeader.Open()
|
file, err := fileHeader.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -167,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
// If multiple images, use image[] as the field name
|
|
||||||
fieldName := "image"
|
fieldName := "image"
|
||||||
if len(imageFiles) > 1 {
|
if len(imageFiles) > 1 {
|
||||||
fieldName = "image[]"
|
fieldName = "image[]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine MIME type based on file extension
|
|
||||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
mimeType := detectImageMimeType(fileHeader.Filename)
|
||||||
|
|
||||||
// Create a form file with the appropriate content type
|
|
||||||
h := make(textproto.MIMEHeader)
|
h := make(textproto.MIMEHeader)
|
||||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||||
h.Set("Content-Type", mimeType)
|
h.Set("Content-Type", mimeType)
|
||||||
@@ -191,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle mask file if present
|
|
||||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||||
maskFile, err := maskFiles[0].Open()
|
maskFile, err := maskFiles[0].Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -199,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
defer maskFile.Close()
|
defer maskFile.Close()
|
||||||
|
|
||||||
// Determine MIME type for mask file
|
|
||||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||||
|
|
||||||
// Create a form file with the appropriate content type
|
|
||||||
h := make(textproto.MIMEHeader)
|
h := make(textproto.MIMEHeader)
|
||||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||||
h.Set("Content-Type", mimeType)
|
h.Set("Content-Type", mimeType)
|
||||||
@@ -220,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
return nil, errors.New("no multipart form data found")
|
return nil, errors.New("no multipart form data found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 关闭 multipart 编写器以设置分界线
|
|
||||||
writer.Close()
|
writer.Close()
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
return bytes.NewReader(requestBody.Bytes()), nil
|
return bytes.NewReader(requestBody.Bytes()), nil
|
||||||
@@ -230,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectImageMimeType determines the MIME type based on the file extension
|
|
||||||
func detectImageMimeType(filename string) string {
|
func detectImageMimeType(filename string) string {
|
||||||
ext := strings.ToLower(filepath.Ext(filename))
|
ext := strings.ToLower(filepath.Ext(filename))
|
||||||
switch ext {
|
switch ext {
|
||||||
@@ -241,11 +220,9 @@ func detectImageMimeType(filename string) string {
|
|||||||
case ".webp":
|
case ".webp":
|
||||||
return "image/webp"
|
return "image/webp"
|
||||||
default:
|
default:
|
||||||
// Try to detect from extension if possible
|
|
||||||
if strings.HasPrefix(ext, ".jp") {
|
if strings.HasPrefix(ext, ".jp") {
|
||||||
return "image/jpeg"
|
return "image/jpeg"
|
||||||
}
|
}
|
||||||
// Default to png as a fallback
|
|
||||||
return "image/png"
|
return "image/png"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,7 +258,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||||
case constant.RelayModeAudioSpeech:
|
case constant.RelayModeAudioSpeech:
|
||||||
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口
|
|
||||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||||
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
||||||
}
|
}
|
||||||
@@ -312,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
// 适配 方舟deepseek混合模型 的 thinking 后缀
|
|
||||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||||
request.Model = info.UpstreamModelName
|
request.Model = info.UpstreamModelName
|
||||||
@@ -330,18 +306,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||||
// TODO implement me
|
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
// For TTS with WebSocket, skip traditional HTTP request
|
|
||||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
baseUrl := info.ChannelBaseUrl
|
baseUrl := info.ChannelBaseUrl
|
||||||
if baseUrl == "" {
|
if baseUrl == "" {
|
||||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||||
}
|
}
|
||||||
// Only use WebSocket for official Volcengine endpoint
|
|
||||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -11,69 +11,45 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// EventType defines the event type which determines the event of the message.
|
EventType int32
|
||||||
EventType int32
|
MsgType uint8
|
||||||
// MsgType defines message type which determines how the message will be
|
MsgTypeFlagBits uint8
|
||||||
// serialized with the protocol.
|
VersionBits uint8
|
||||||
MsgType uint8
|
HeaderSizeBits uint8
|
||||||
// MsgTypeFlagBits defines the 4-bit message-type specific flags. The specific
|
|
||||||
// values should be defined in each specific usage scenario.
|
|
||||||
MsgTypeFlagBits uint8
|
|
||||||
// VersionBits defines the 4-bit version type.
|
|
||||||
VersionBits uint8
|
|
||||||
// HeaderSizeBits defines the 4-bit header-size type.
|
|
||||||
HeaderSizeBits uint8
|
|
||||||
// SerializationBits defines the 4-bit serialization method type.
|
|
||||||
SerializationBits uint8
|
SerializationBits uint8
|
||||||
// CompressionBits defines the 4-bit compression method type.
|
CompressionBits uint8
|
||||||
CompressionBits uint8
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MsgTypeFlagNoSeq MsgTypeFlagBits = 0 // Non-terminal packet with no sequence
|
MsgTypeFlagNoSeq MsgTypeFlagBits = 0
|
||||||
MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 // Non-terminal packet with sequence > 0
|
MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1
|
||||||
MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0b10 // last packet with no sequence
|
MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11
|
||||||
MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 // last packet with sequence < 0
|
MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100
|
||||||
MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 // Payload contains event number (int32)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Version1 VersionBits = iota + 1
|
Version1 VersionBits = iota + 1
|
||||||
Version2
|
|
||||||
Version3
|
|
||||||
Version4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HeaderSize4 HeaderSizeBits = iota + 1
|
HeaderSize4 HeaderSizeBits = iota + 1
|
||||||
HeaderSize8
|
|
||||||
HeaderSize12
|
|
||||||
HeaderSize16
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SerializationRaw SerializationBits = 0
|
SerializationJSON SerializationBits = 0b1
|
||||||
SerializationJSON SerializationBits = 0b1
|
|
||||||
SerializationThrift SerializationBits = 0b11
|
|
||||||
SerializationCustom SerializationBits = 0b1111
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CompressionNone CompressionBits = 0
|
CompressionNone CompressionBits = 0
|
||||||
CompressionGzip CompressionBits = 0b1
|
|
||||||
CompressionCustom CompressionBits = 0b1111
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MsgTypeInvalid MsgType = 0
|
|
||||||
MsgTypeFullClientRequest MsgType = 0b1
|
MsgTypeFullClientRequest MsgType = 0b1
|
||||||
MsgTypeAudioOnlyClient MsgType = 0b10
|
MsgTypeAudioOnlyClient MsgType = 0b10
|
||||||
MsgTypeFullServerResponse MsgType = 0b1001
|
MsgTypeFullServerResponse MsgType = 0b1001
|
||||||
MsgTypeAudioOnlyServer MsgType = 0b1011
|
MsgTypeAudioOnlyServer MsgType = 0b1011
|
||||||
MsgTypeFrontEndResultServer MsgType = 0b1100
|
MsgTypeFrontEndResultServer MsgType = 0b1100
|
||||||
MsgTypeError MsgType = 0b1111
|
MsgTypeError MsgType = 0b1111
|
||||||
|
|
||||||
MsgTypeServerACK = MsgTypeAudioOnlyServer
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t MsgType) String() string {
|
func (t MsgType) String() string {
|
||||||
@@ -85,7 +61,7 @@ func (t MsgType) String() string {
|
|||||||
case MsgTypeFullServerResponse:
|
case MsgTypeFullServerResponse:
|
||||||
return "MsgType_FullServerResponse"
|
return "MsgType_FullServerResponse"
|
||||||
case MsgTypeAudioOnlyServer:
|
case MsgTypeAudioOnlyServer:
|
||||||
return "MsgType_AudioOnlyServer" // MsgTypeServerACK
|
return "MsgType_AudioOnlyServer"
|
||||||
case MsgTypeError:
|
case MsgTypeError:
|
||||||
return "MsgType_Error"
|
return "MsgType_Error"
|
||||||
case MsgTypeFrontEndResultServer:
|
case MsgTypeFrontEndResultServer:
|
||||||
@@ -96,44 +72,33 @@ func (t MsgType) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Default event, applicable for scenarios not using events or not requiring event transmission,
|
|
||||||
// or for scenarios using events, non-zero values can be used to validate event legitimacy
|
|
||||||
EventType_None EventType = 0
|
EventType_None EventType = 0
|
||||||
// 1 ~ 49 for upstream Connection events
|
|
||||||
EventType_StartConnection EventType = 1
|
EventType_StartConnection EventType = 1
|
||||||
EventType_StartTask EventType = 1 // Alias of "StartConnection"
|
|
||||||
EventType_FinishConnection EventType = 2
|
EventType_FinishConnection EventType = 2
|
||||||
EventType_FinishTask EventType = 2 // Alias of "FinishConnection"
|
|
||||||
// 50 ~ 99 for downstream Connection events
|
EventType_ConnectionStarted EventType = 50
|
||||||
// Connection established successfully
|
EventType_ConnectionFailed EventType = 51
|
||||||
EventType_ConnectionStarted EventType = 50
|
|
||||||
EventType_TaskStarted EventType = 50 // Alias of "ConnectionStarted"
|
|
||||||
// Connection failed (possibly due to authentication failure)
|
|
||||||
EventType_ConnectionFailed EventType = 51
|
|
||||||
EventType_TaskFailed EventType = 51 // Alias of "ConnectionFailed"
|
|
||||||
// Connection ended
|
|
||||||
EventType_ConnectionFinished EventType = 52
|
EventType_ConnectionFinished EventType = 52
|
||||||
EventType_TaskFinished EventType = 52 // Alias of "ConnectionFinished"
|
|
||||||
// 100 ~ 149 for upstream Session events
|
|
||||||
EventType_StartSession EventType = 100
|
EventType_StartSession EventType = 100
|
||||||
EventType_CancelSession EventType = 101
|
EventType_CancelSession EventType = 101
|
||||||
EventType_FinishSession EventType = 102
|
EventType_FinishSession EventType = 102
|
||||||
// 150 ~ 199 for downstream Session events
|
|
||||||
EventType_SessionStarted EventType = 150
|
EventType_SessionStarted EventType = 150
|
||||||
EventType_SessionCanceled EventType = 151
|
EventType_SessionCanceled EventType = 151
|
||||||
EventType_SessionFinished EventType = 152
|
EventType_SessionFinished EventType = 152
|
||||||
EventType_SessionFailed EventType = 153
|
EventType_SessionFailed EventType = 153
|
||||||
// Usage events
|
|
||||||
EventType_UsageResponse EventType = 154
|
EventType_UsageResponse EventType = 154
|
||||||
EventType_ChargeData EventType = 154 // Alias of "UsageResponse"
|
|
||||||
// 200 ~ 249 for upstream general events
|
|
||||||
EventType_TaskRequest EventType = 200
|
EventType_TaskRequest EventType = 200
|
||||||
EventType_UpdateConfig EventType = 201
|
EventType_UpdateConfig EventType = 201
|
||||||
// 250 ~ 299 for downstream general events
|
|
||||||
EventType_AudioMuted EventType = 250
|
EventType_AudioMuted EventType = 250
|
||||||
// 300 ~ 349 for upstream TTS events
|
|
||||||
EventType_SayHello EventType = 300
|
EventType_SayHello EventType = 300
|
||||||
// 350 ~ 399 for downstream TTS events
|
|
||||||
EventType_TTSSentenceStart EventType = 350
|
EventType_TTSSentenceStart EventType = 350
|
||||||
EventType_TTSSentenceEnd EventType = 351
|
EventType_TTSSentenceEnd EventType = 351
|
||||||
EventType_TTSResponse EventType = 352
|
EventType_TTSResponse EventType = 352
|
||||||
@@ -141,22 +106,20 @@ const (
|
|||||||
EventType_PodcastRoundStart EventType = 360
|
EventType_PodcastRoundStart EventType = 360
|
||||||
EventType_PodcastRoundResponse EventType = 361
|
EventType_PodcastRoundResponse EventType = 361
|
||||||
EventType_PodcastRoundEnd EventType = 362
|
EventType_PodcastRoundEnd EventType = 362
|
||||||
// 450 ~ 499 for downstream ASR events
|
|
||||||
EventType_ASRInfo EventType = 450
|
EventType_ASRInfo EventType = 450
|
||||||
EventType_ASRResponse EventType = 451
|
EventType_ASRResponse EventType = 451
|
||||||
EventType_ASREnded EventType = 459
|
EventType_ASREnded EventType = 459
|
||||||
// 500 ~ 549 for upstream dialogue events
|
|
||||||
// (Ground-Truth-Alignment) text for speech synthesis
|
|
||||||
EventType_ChatTTSText EventType = 500
|
EventType_ChatTTSText EventType = 500
|
||||||
// 550 ~ 599 for downstream dialogue events
|
|
||||||
EventType_ChatResponse EventType = 550
|
EventType_ChatResponse EventType = 550
|
||||||
EventType_ChatEnded EventType = 559
|
EventType_ChatEnded EventType = 559
|
||||||
// 650 ~ 699 for downstream dialogue events
|
|
||||||
// Events for source (original) language subtitle.
|
|
||||||
EventType_SourceSubtitleStart EventType = 650
|
EventType_SourceSubtitleStart EventType = 650
|
||||||
EventType_SourceSubtitleResponse EventType = 651
|
EventType_SourceSubtitleResponse EventType = 651
|
||||||
EventType_SourceSubtitleEnd EventType = 652
|
EventType_SourceSubtitleEnd EventType = 652
|
||||||
// Events for target (translation) language subtitle.
|
|
||||||
EventType_TranslationSubtitleStart EventType = 653
|
EventType_TranslationSubtitleStart EventType = 653
|
||||||
EventType_TranslationSubtitleResponse EventType = 654
|
EventType_TranslationSubtitleResponse EventType = 654
|
||||||
EventType_TranslationSubtitleEnd EventType = 655
|
EventType_TranslationSubtitleEnd EventType = 655
|
||||||
@@ -243,26 +206,6 @@ func (t EventType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 0 1 2 3
|
|
||||||
// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 |
|
|
||||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
|
||||||
// | Version | Header Size | Msg Type | Flags |
|
|
||||||
// | (4 bits) | (4 bits) | (4 bits) | (4 bits) |
|
|
||||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
|
||||||
// | Serialization | Compression | Reserved |
|
|
||||||
// | (4 bits) | (4 bits) | (8 bits) |
|
|
||||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
|
||||||
// | |
|
|
||||||
// | Optional Header Extensions |
|
|
||||||
// | (if Header Size > 1) |
|
|
||||||
// | |
|
|
||||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
|
||||||
// | |
|
|
||||||
// | Payload |
|
|
||||||
// | (variable length) |
|
|
||||||
// | |
|
|
||||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Version VersionBits
|
Version VersionBits
|
||||||
HeaderSize HeaderSizeBits
|
HeaderSize HeaderSizeBits
|
||||||
@@ -573,140 +516,15 @@ func ReceiveMessage(conn *websocket.Conn) (*Message, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Log: receive msg
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) {
|
|
||||||
for {
|
|
||||||
msg, err := ReceiveMessage(conn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if msg.MsgType != msgType || msg.EventType != eventType {
|
|
||||||
return nil, fmt.Errorf("unexpected message: %s", msg)
|
|
||||||
}
|
|
||||||
if msg.MsgType == msgType && msg.EventType == eventType {
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FullClientRequest(conn *websocket.Conn, payload []byte) error {
|
func FullClientRequest(conn *websocket.Conn, payload []byte) error {
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
|
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
msg.Payload = payload
|
msg.Payload = payload
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error {
|
|
||||||
msg, err := NewMessage(MsgTypeAudioOnlyClient, flag)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.Payload = payload
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartConnection(conn *websocket.Conn) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_StartConnection
|
|
||||||
msg.Payload = []byte("{}")
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FinishConnection(conn *websocket.Conn) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_FinishConnection
|
|
||||||
msg.Payload = []byte("{}")
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_StartSession
|
|
||||||
msg.SessionID = sessionID
|
|
||||||
msg.Payload = payload
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FinishSession(conn *websocket.Conn, sessionID string) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_FinishSession
|
|
||||||
msg.SessionID = sessionID
|
|
||||||
msg.Payload = []byte("{}")
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CancelSession(conn *websocket.Conn, sessionID string) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_CancelSession
|
|
||||||
msg.SessionID = sessionID
|
|
||||||
msg.Payload = []byte("{}")
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error {
|
|
||||||
msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.EventType = EventType_TaskRequest
|
|
||||||
msg.SessionID = sessionID
|
|
||||||
msg.Payload = payload
|
|
||||||
// Log: send msg
|
|
||||||
frame, err := msg.Marshal()
|
frame, err := msg.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -196,9 +196,7 @@ func generateRequestID() string {
|
|||||||
return uuid.New().String()
|
return uuid.New().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleTTSWebSocketResponse handles streaming TTS response via WebSocket
|
|
||||||
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
|
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
|
||||||
// Parse API key for auth
|
|
||||||
_, token, parseErr := parseVolcengineAuth(info.ApiKey)
|
_, token, parseErr := parseVolcengineAuth(info.ApiKey)
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return nil, types.NewErrorWithStatusCode(
|
return nil, types.NewErrorWithStatusCode(
|
||||||
@@ -208,11 +206,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup WebSocket headers
|
|
||||||
header := http.Header{}
|
header := http.Header{}
|
||||||
header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
|
header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
|
||||||
|
|
||||||
// Dial WebSocket connection
|
|
||||||
conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
|
conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
|
||||||
if dialErr != nil {
|
if dialErr != nil {
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
@@ -239,7 +235,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send full client request
|
|
||||||
if sendErr := FullClientRequest(conn, payload); sendErr != nil {
|
if sendErr := FullClientRequest(conn, payload); sendErr != nil {
|
||||||
return nil, types.NewErrorWithStatusCode(
|
return nil, types.NewErrorWithStatusCode(
|
||||||
fmt.Errorf("failed to send request: %w", sendErr),
|
fmt.Errorf("failed to send request: %w", sendErr),
|
||||||
@@ -248,13 +243,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set response headers
|
|
||||||
contentType := getContentTypeByEncoding(encoding)
|
contentType := getContentTypeByEncoding(encoding)
|
||||||
c.Header("Content-Type", contentType)
|
c.Header("Content-Type", contentType)
|
||||||
c.Header("Transfer-Encoding", "chunked")
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
// Stream audio data
|
|
||||||
var audioBuffer []byte
|
|
||||||
for {
|
for {
|
||||||
msg, recvErr := ReceiveMessage(conn)
|
msg, recvErr := ReceiveMessage(conn)
|
||||||
if recvErr != nil {
|
if recvErr != nil {
|
||||||
@@ -279,7 +271,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
continue
|
continue
|
||||||
case MsgTypeAudioOnlyServer:
|
case MsgTypeAudioOnlyServer:
|
||||||
if len(msg.Payload) > 0 {
|
if len(msg.Payload) > 0 {
|
||||||
audioBuffer = append(audioBuffer, msg.Payload...)
|
|
||||||
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
||||||
return nil, types.NewErrorWithStatusCode(
|
return nil, types.NewErrorWithStatusCode(
|
||||||
fmt.Errorf("failed to write audio data: %w", writeErr),
|
fmt.Errorf("failed to write audio data: %w", writeErr),
|
||||||
@@ -287,7 +278,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
|
|||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
//logger.Infof("write audio chunk size: %d", len(msg.Payload))
|
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user