feat: doubao tts support streaming realtime audio
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -13,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type VolcengineTTSRequest struct {
|
||||
@@ -192,3 +195,129 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
func generateRequestID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// handleTTSWebSocketResponse handles streaming TTS response via WebSocket
|
||||
func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest VolcengineTTSRequest, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
|
||||
// Parse API key for auth
|
||||
_, token, parseErr := parseVolcengineAuth(info.ApiKey)
|
||||
if parseErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
parseErr,
|
||||
types.ErrorCodeChannelInvalidKey,
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
}
|
||||
|
||||
// Setup WebSocket headers
|
||||
header := http.Header{}
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer;%s", token))
|
||||
|
||||
// Dial WebSocket connection
|
||||
conn, resp, dialErr := websocket.DefaultDialer.DialContext(context.Background(), requestURL, header)
|
||||
if dialErr != nil {
|
||||
if resp != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to connect to websocket: %w, status: %d", dialErr, resp.StatusCode),
|
||||
types.ErrorCodeBadResponseStatusCode,
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
}
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to connect to websocket: %w", dialErr),
|
||||
types.ErrorCodeBadResponseStatusCode,
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Update request operation to "submit" for WebSocket
|
||||
volcRequest.Request.Operation = "submit"
|
||||
|
||||
// Marshal request payload
|
||||
payload, marshalErr := json.Marshal(volcRequest)
|
||||
if marshalErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to marshal request: %w", marshalErr),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Send full client request
|
||||
if sendErr := FullClientRequest(conn, payload); sendErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to send request: %w", sendErr),
|
||||
types.ErrorCodeBadRequestBody,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Set response headers
|
||||
contentType := getContentTypeByEncoding(encoding)
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// Stream audio data
|
||||
var audioBuffer []byte
|
||||
for {
|
||||
msg, recvErr := ReceiveMessage(conn)
|
||||
if recvErr != nil {
|
||||
if websocket.IsCloseError(recvErr, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
break
|
||||
}
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to receive message: %w", recvErr),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
switch msg.MsgType {
|
||||
case MsgTypeError:
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("received error from server: code=%d, %s", msg.ErrorCode, string(msg.Payload)),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
case MsgTypeFrontEndResultServer:
|
||||
// Metadata response, can be logged or processed
|
||||
continue
|
||||
case MsgTypeAudioOnlyServer:
|
||||
// Stream audio chunk to client
|
||||
if len(msg.Payload) > 0 {
|
||||
audioBuffer = append(audioBuffer, msg.Payload...)
|
||||
if _, writeErr := c.Writer.Write(msg.Payload); writeErr != nil {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("failed to write audio data: %w", writeErr),
|
||||
types.ErrorCodeBadResponse,
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Check if this is the last packet (negative sequence)
|
||||
if msg.Sequence < 0 {
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
default:
|
||||
// Unknown message type, log and continue
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, connection closed without final packet
|
||||
c.Status(http.StatusOK)
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user