diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index c7c2a92b..f46328e3 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -24,7 +24,6 @@ import ( ) const ( - // Context keys for passing data between methods contextKeyTTSRequest = "volcengine_tts_request" contextKeyResponseFormat = "response_format" ) @@ -76,27 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf Request: VolcengineTTSReqInfo{ ReqID: generateRequestID(), Text: request.Input, - Operation: "submit", // default WebSocket uses "submit" + Operation: "submit", Model: info.OriginModelName, }, } - // 同步扩展字段的厂商自定义metadata if len(request.Metadata) > 0 { if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil { return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err) } } - // Store the request in context for WebSocket handler c.Set(contextKeyTTSRequest, volcRequest) - // https://www.volcengine.com/docs/6561/1257584 - // operation需要设置为submit才是流式返回 + if volcRequest.Request.Operation == "submit" { info.IsStream = true } - // Return nil as WebSocket doesn't use traditional request body jsonData, err := json.Marshal(volcRequest) if err != nil { return nil, fmt.Errorf("error marshalling volcengine request: %w", err) @@ -115,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf writer := multipart.NewWriter(&requestBody) writer.WriteField("model", request.Model) - // 获取所有表单字段 + formData := c.Request.PostForm - // 遍历表单字段并打印输出 for key, values := range formData { if key == "model" { continue @@ -127,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Parse the multipart form to handle both single image and multiple images - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { return nil, errors.New("failed to parse multipart form") } if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { - // Check if "image" field exists in any form, including array notation var imageFiles []*multipart.FileHeader var exists bool - // First check for standard "image" field if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { - // If not found, check for "image[]" field if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { - // If still not found, iterate through all fields to find any that start with "image[" foundArrayImages := false for fieldName, files := range c.Request.MultipartForm.File { if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { @@ -152,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // If no image fields found at all if !foundArrayImages && (len(imageFiles) == 0) { return nil, errors.New("image is required") } } } - // Process all image files for i, fileHeader := range imageFiles { file, err := fileHeader.Open() if err != nil { @@ -167,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer file.Close() - // If multiple images, use image[] as the field name fieldName := "image" if len(imageFiles) > 1 { fieldName = "image[]" } - // Determine MIME type based on file extension mimeType := detectImageMimeType(fileHeader.Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) h.Set("Content-Type", mimeType) @@ -191,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } - // Handle mask file if present if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { maskFile, err := maskFiles[0].Open() if err != nil { @@ -199,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } defer maskFile.Close() - // Determine MIME type for mask file mimeType := detectImageMimeType(maskFiles[0].Filename) - // Create a form file with the appropriate content type h := make(textproto.MIMEHeader) h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) h.Set("Content-Type", mimeType) @@ -220,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("no multipart form data found") } - // 关闭 multipart 编写器以设置分界线 writer.Close() c.Request.Header.Set("Content-Type", writer.FormDataContentType()) return bytes.NewReader(requestBody.Bytes()), nil @@ -230,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } } -// detectImageMimeType determines the MIME type based on the file extension func detectImageMimeType(filename string) string { ext := strings.ToLower(filepath.Ext(filename)) switch ext { @@ -241,11 +220,9 @@ func detectImageMimeType(filename string) string { case ".webp": return "image/webp" default: - // Try to detect from extension if possible if strings.HasPrefix(ext, ".jp") { return "image/jpeg" } - // Default to png as a fallback return "image/png" } } @@ -281,7 +258,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeRerank: return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil case constant.RelayModeAudioSpeech: - // 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口,否则走透传的New接口 if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil } @@ -312,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - // 适配 方舟deepseek混合模型 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") request.Model = info.UpstreamModelName @@ -330,18 +306,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - // For TTS with WebSocket, skip traditional HTTP request if info.RelayMode == constant.RelayModeAudioSpeech { baseUrl := info.ChannelBaseUrl if baseUrl == "" { baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] } - // Only use WebSocket for official Volcengine endpoint + if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] { if info.IsStream { return nil, nil diff --git a/relay/channel/volcengine/protocols.go b/relay/channel/volcengine/protocols.go index a41d8756..c978e1c7 100644 --- a/relay/channel/volcengine/protocols.go +++ b/relay/channel/volcengine/protocols.go @@ -11,69 +11,45 @@ import ( ) type ( - // EventType defines the event type which determines the event of the message. - EventType int32 - // MsgType defines message type which determines how the message will be - // serialized with the protocol. - MsgType uint8 - // MsgTypeFlagBits defines the 4-bit message-type specific flags. The specific - // values should be defined in each specific usage scenario. - MsgTypeFlagBits uint8 - // VersionBits defines the 4-bit version type. - VersionBits uint8 - // HeaderSizeBits defines the 4-bit header-size type. - HeaderSizeBits uint8 - // SerializationBits defines the 4-bit serialization method type. + EventType int32 + MsgType uint8 + MsgTypeFlagBits uint8 + VersionBits uint8 + HeaderSizeBits uint8 SerializationBits uint8 - // CompressionBits defines the 4-bit compression method type. - CompressionBits uint8 + CompressionBits uint8 ) const ( - MsgTypeFlagNoSeq MsgTypeFlagBits = 0 // Non-terminal packet with no sequence - MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 // Non-terminal packet with sequence > 0 - MsgTypeFlagLastNoSeq MsgTypeFlagBits = 0b10 // last packet with no sequence - MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 // last packet with sequence < 0 - MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 // Payload contains event number (int32) + MsgTypeFlagNoSeq MsgTypeFlagBits = 0 + MsgTypeFlagPositiveSeq MsgTypeFlagBits = 0b1 + MsgTypeFlagNegativeSeq MsgTypeFlagBits = 0b11 + MsgTypeFlagWithEvent MsgTypeFlagBits = 0b100 ) const ( Version1 VersionBits = iota + 1 - Version2 - Version3 - Version4 ) const ( HeaderSize4 HeaderSizeBits = iota + 1 - HeaderSize8 - HeaderSize12 - HeaderSize16 ) const ( - SerializationRaw SerializationBits = 0 - SerializationJSON SerializationBits = 0b1 - SerializationThrift SerializationBits = 0b11 - SerializationCustom SerializationBits = 0b1111 + SerializationJSON SerializationBits = 0b1 ) const ( - CompressionNone CompressionBits = 0 - CompressionGzip CompressionBits = 0b1 - CompressionCustom CompressionBits = 0b1111 + CompressionNone CompressionBits = 0 ) const ( - MsgTypeInvalid MsgType = 0 MsgTypeFullClientRequest MsgType = 0b1 MsgTypeAudioOnlyClient MsgType = 0b10 MsgTypeFullServerResponse MsgType = 0b1001 MsgTypeAudioOnlyServer MsgType = 0b1011 MsgTypeFrontEndResultServer MsgType = 0b1100 MsgTypeError MsgType = 0b1111 - - MsgTypeServerACK = MsgTypeAudioOnlyServer ) func (t MsgType) String() string { @@ -85,7 +61,7 @@ func (t MsgType) String() string { case MsgTypeFullServerResponse: return "MsgType_FullServerResponse" case MsgTypeAudioOnlyServer: - return "MsgType_AudioOnlyServer" // MsgTypeServerACK + return "MsgType_AudioOnlyServer" case MsgTypeError: return "MsgType_Error" case MsgTypeFrontEndResultServer: @@ -96,44 +72,33 @@ func (t MsgType) String() string { } const ( - // Default event, applicable for scenarios not using events or not requiring event transmission, - // or for scenarios using events, non-zero values can be used to validate event legitimacy EventType_None EventType = 0 - // 1 ~ 49 for upstream Connection events + EventType_StartConnection EventType = 1 - EventType_StartTask EventType = 1 // Alias of "StartConnection" EventType_FinishConnection EventType = 2 - EventType_FinishTask EventType = 2 // Alias of "FinishConnection" - // 50 ~ 99 for downstream Connection events - // Connection established successfully - EventType_ConnectionStarted EventType = 50 - EventType_TaskStarted EventType = 50 // Alias of "ConnectionStarted" - // Connection failed (possibly due to authentication failure) - EventType_ConnectionFailed EventType = 51 - EventType_TaskFailed EventType = 51 // Alias of "ConnectionFailed" - // Connection ended + + EventType_ConnectionStarted EventType = 50 + EventType_ConnectionFailed EventType = 51 EventType_ConnectionFinished EventType = 52 - EventType_TaskFinished EventType = 52 // Alias of "ConnectionFinished" - // 100 ~ 149 for upstream Session events + EventType_StartSession EventType = 100 EventType_CancelSession EventType = 101 EventType_FinishSession EventType = 102 - // 150 ~ 199 for downstream Session events + EventType_SessionStarted EventType = 150 EventType_SessionCanceled EventType = 151 EventType_SessionFinished EventType = 152 EventType_SessionFailed EventType = 153 - // Usage events + EventType_UsageResponse EventType = 154 - EventType_ChargeData EventType = 154 // Alias of "UsageResponse" - // 200 ~ 249 for upstream general events + EventType_TaskRequest EventType = 200 EventType_UpdateConfig EventType = 201 - // 250 ~ 299 for downstream general events + EventType_AudioMuted EventType = 250 - // 300 ~ 349 for upstream TTS events + EventType_SayHello EventType = 300 - // 350 ~ 399 for downstream TTS events + EventType_TTSSentenceStart EventType = 350 EventType_TTSSentenceEnd EventType = 351 EventType_TTSResponse EventType = 352 @@ -141,22 +106,20 @@ const ( EventType_PodcastRoundStart EventType = 360 EventType_PodcastRoundResponse EventType = 361 EventType_PodcastRoundEnd EventType = 362 - // 450 ~ 499 for downstream ASR events + EventType_ASRInfo EventType = 450 EventType_ASRResponse EventType = 451 EventType_ASREnded EventType = 459 - // 500 ~ 549 for upstream dialogue events - // (Ground-Truth-Alignment) text for speech synthesis + EventType_ChatTTSText EventType = 500 - // 550 ~ 599 for downstream dialogue events + EventType_ChatResponse EventType = 550 EventType_ChatEnded EventType = 559 - // 650 ~ 699 for downstream dialogue events - // Events for source (original) language subtitle. + EventType_SourceSubtitleStart EventType = 650 EventType_SourceSubtitleResponse EventType = 651 EventType_SourceSubtitleEnd EventType = 652 - // Events for target (translation) language subtitle. + EventType_TranslationSubtitleStart EventType = 653 EventType_TranslationSubtitleResponse EventType = 654 EventType_TranslationSubtitleEnd EventType = 655 @@ -243,26 +206,6 @@ func (t EventType) String() string { } } -// 0 1 2 3 -// | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Version | Header Size | Msg Type | Flags | -// | (4 bits) | (4 bits) | (4 bits) | (4 bits) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Serialization | Compression | Reserved | -// | (4 bits) | (4 bits) | (8 bits) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | | -// | Optional Header Extensions | -// | (if Header Size > 1) | -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | | -// | Payload | -// | (variable length) | -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - type Message struct { Version VersionBits HeaderSize HeaderSizeBits @@ -573,140 +516,15 @@ func ReceiveMessage(conn *websocket.Conn) (*Message, error) { if err != nil { return nil, err } - // Log: receive msg return msg, nil } -func WaitForEvent(conn *websocket.Conn, msgType MsgType, eventType EventType) (*Message, error) { - for { - msg, err := ReceiveMessage(conn) - if err != nil { - return nil, err - } - if msg.MsgType != msgType || msg.EventType != eventType { - return nil, fmt.Errorf("unexpected message: %s", msg) - } - if msg.MsgType == msgType && msg.EventType == eventType { - return msg, nil - } - } -} - func FullClientRequest(conn *websocket.Conn, payload []byte) error { msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagNoSeq) if err != nil { return err } msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func AudioOnlyClient(conn *websocket.Conn, payload []byte, flag MsgTypeFlagBits) error { - msg, err := NewMessage(MsgTypeAudioOnlyClient, flag) - if err != nil { - return err - } - msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func StartConnection(conn *websocket.Conn) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_StartConnection - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func FinishConnection(conn *websocket.Conn) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_FinishConnection - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func StartSession(conn *websocket.Conn, payload []byte, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_StartSession - msg.SessionID = sessionID - msg.Payload = payload - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func FinishSession(conn *websocket.Conn, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_FinishSession - msg.SessionID = sessionID - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func CancelSession(conn *websocket.Conn, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_CancelSession - msg.SessionID = sessionID - msg.Payload = []byte("{}") - // Log: send msg - frame, err := msg.Marshal() - if err != nil { - return err - } - return conn.WriteMessage(websocket.BinaryMessage, frame) -} - -func TaskRequest(conn *websocket.Conn, payload []byte, sessionID string) error { - msg, err := NewMessage(MsgTypeFullClientRequest, MsgTypeFlagWithEvent) - if err != nil { - return err - } - msg.EventType = EventType_TaskRequest - msg.SessionID = sessionID - msg.Payload = payload - // Log: send msg frame, err := msg.Marshal() if err != nil { return err diff --git a/relay/channel/volcengine/tts.go b/relay/channel/volcengine/tts.go index 033737a5..166fab8e 100644 --- a/relay/channel/volcengine/tts.go +++ b/relay/channel/volcengine/tts.go @@ -196,9 +196,7 @@ func generateRequestID() string { return uuid.New().String() } -// handleTTSWebSocketResponse handles streaming TTS response via WebSocket func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) { - // Parse API key for auth _, token, parseErr := parseVolcengineAuth(info.ApiKey) if parseErr != nil { return nil, types.NewErrorWithStatusCode( @@ -208,11 +206,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Setup WebSocket headers header := http.Header{} header.Set("Authorization", fmt.Sprintf("Bearer;%s", token)) - // Dial WebSocket connection conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header) if dialErr != nil { if resp != nil { @@ -239,7 +235,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Send full client request if sendErr := FullClientRequest(conn, payload); sendErr != nil { return nil, types.NewErrorWithStatusCode( fmt.Errorf("failed to send request: %w", sendErr), @@ -248,13 +243,10 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V ) } - // Set response headers contentType := getContentTypeByEncoding(encoding) c.Header("Content-Type", contentType) c.Header("Transfer-Encoding", "chunked") - // Stream audio data - var audioBuffer []byte for { msg, recvErr := ReceiveMessage(conn) if recvErr != nil { @@ -279,7 +271,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V continue case MsgTypeAudioOnlyServer: if len(msg.Payload) > 0 { - audioBuffer = append(audioBuffer, msg.Payload...) if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil { return nil, types.NewErrorWithStatusCode( fmt.Errorf("failed to write audio data: %w", writeErr), @@ -287,7 +278,6 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V http.StatusInternalServerError, ) } - //logger.Infof("write audio chunk size: %d", len(msg.Payload)) c.Writer.Flush() }