refactor: clean up doubao tts code

This commit is contained in:
feitianbubu
2025-10-22 16:48:00 +08:00
parent 828bb17d2c
commit fd6cd838f7
3 changed files with 35 additions and 253 deletions

View File

@@ -24,7 +24,6 @@ import (
)
const (
// Context keys for passing data between methods
contextKeyTTSRequest = "volcengine_tts_request"
contextKeyResponseFormat = "response_format"
)
@@ -76,27 +75,23 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
Request: VolcengineTTSReqInfo{
ReqID: generateRequestID(),
Text: request.Input,
Operation: "submit", // default WebSocket uses "submit"
Operation: "submit",
Model: info.OriginModelName,
},
}
// 同步扩展字段的厂商自定义metadata
if len(request.Metadata) > 0 {
if err = json.Unmarshal(request.Metadata, &volcRequest); err != nil {
return nil, fmt.Errorf("error unmarshalling metadata to volcengine request: %w", err)
}
}
// Store the request in context for WebSocket handler
c.Set(contextKeyTTSRequest, volcRequest)
// https://www.volcengine.com/docs/6561/1257584
// operation需要设置为submit才是流式返回
if volcRequest.Request.Operation == "submit" {
info.IsStream = true
}
// Return nil as WebSocket doesn't use traditional request body
jsonData, err := json.Marshal(volcRequest)
if err != nil {
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
@@ -115,9 +110,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
@@ -127,21 +121,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
@@ -152,14 +141,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
@@ -167,16 +154,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
@@ -191,7 +175,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
@@ -199,10 +182,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
@@ -220,7 +201,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
@@ -230,7 +210,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
@@ -241,11 +220,9 @@ func detectImageMimeType(filename string) string {
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
@@ -281,7 +258,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
// 只有当 baseUrl 是火山默认的官方Url时才改为官方的的TTS接口否则走透传的New接口
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
}
@@ -312,7 +288,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
// 适配 方舟deepseek混合模型 的 thinking 后缀
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
request.Model = info.UpstreamModelName
@@ -330,18 +306,16 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// For TTS with WebSocket, skip traditional HTTP request
if info.RelayMode == constant.RelayModeAudioSpeech {
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
}
// Only use WebSocket for official Volcengine endpoint
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
if info.IsStream {
return nil, nil