feat: add doubao audio tts
This commit is contained in:
@@ -37,8 +37,49 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
//TODO implement me
|
if info.RelayMode != constant.RelayModeAudioSpeech {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("unsupported audio relay mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
appID, token, err := parseVolcengineAuth(info.ApiKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
voiceType := mapVoiceType(request.Voice)
|
||||||
|
speedRatio := mapSpeedRatio(request.Speed)
|
||||||
|
encoding := mapEncoding(request.ResponseFormat)
|
||||||
|
|
||||||
|
c.Set("response_format", encoding)
|
||||||
|
|
||||||
|
volcRequest := VolcengineTTSRequest{
|
||||||
|
App: VolcengineTTSApp{
|
||||||
|
AppID: appID,
|
||||||
|
Token: token,
|
||||||
|
Cluster: "volcano_tts",
|
||||||
|
},
|
||||||
|
User: VolcengineTTSUser{
|
||||||
|
UID: "openai_relay_user",
|
||||||
|
},
|
||||||
|
Audio: VolcengineTTSAudio{
|
||||||
|
VoiceType: voiceType,
|
||||||
|
Encoding: encoding,
|
||||||
|
SpeedRatio: speedRatio,
|
||||||
|
Rate: 24000,
|
||||||
|
},
|
||||||
|
Request: VolcengineTTSReqInfo{
|
||||||
|
ReqID: generateRequestID(),
|
||||||
|
Text: request.Input,
|
||||||
|
Operation: "query",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(volcRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshalling volcengine request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewReader(jsonData), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
@@ -190,7 +231,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 支持自定义域名,如果未设置则使用默认域名
|
|
||||||
baseUrl := info.ChannelBaseUrl
|
baseUrl := info.ChannelBaseUrl
|
||||||
if baseUrl == "" {
|
if baseUrl == "" {
|
||||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||||
@@ -217,6 +257,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||||
|
case constant.RelayModeAudioSpeech:
|
||||||
|
return "https://openspeech.bytedance.com/api/v1/tts", nil
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -225,6 +267,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
|
|
||||||
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
|
parts := strings.Split(info.ApiKey, "|")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
req.Set("Authorization", "Bearer;"+parts[1])
|
||||||
|
}
|
||||||
|
req.Set("Content-Type", "application/json")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -260,6 +312,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||||
|
encoding := mapEncoding(c.GetString("response_format"))
|
||||||
|
return handleTTSResponse(c, resp, encoding)
|
||||||
|
}
|
||||||
|
|
||||||
adaptor := openai.Adaptor{}
|
adaptor := openai.Adaptor{}
|
||||||
usage, err = adaptor.DoResponse(c, resp, info)
|
usage, err = adaptor.DoResponse(c, resp, info)
|
||||||
return
|
return
|
||||||
|
|||||||
184
relay/channel/volcengine/tts.go
Normal file
184
relay/channel/volcengine/tts.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package volcengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Request payload types for the Volcengine TTS HTTP API. The JSON tags define
// the wire names; the nesting mirrors the four top-level sections of the body.
type (
	// VolcengineTTSRequest is the top-level JSON body of a synthesis request.
	VolcengineTTSRequest struct {
		App     VolcengineTTSApp     `json:"app"`
		User    VolcengineTTSUser    `json:"user"`
		Audio   VolcengineTTSAudio   `json:"audio"`
		Request VolcengineTTSReqInfo `json:"request"`
	}

	// VolcengineTTSApp carries the application credentials and target cluster.
	VolcengineTTSApp struct {
		AppID   string `json:"appid"`
		Token   string `json:"token"`
		Cluster string `json:"cluster"`
	}

	// VolcengineTTSUser identifies the end user on whose behalf audio is
	// synthesized.
	VolcengineTTSUser struct {
		UID string `json:"uid"`
	}

	// VolcengineTTSAudio describes the requested voice and output format.
	// LoudnessRatio is omitted from the JSON when it is zero.
	VolcengineTTSAudio struct {
		VoiceType     string  `json:"voice_type"`
		Encoding      string  `json:"encoding"`
		SpeedRatio    float64 `json:"speed_ratio"`
		Rate          int     `json:"rate"`
		LoudnessRatio float64 `json:"loudness_ratio,omitempty"`
	}

	// VolcengineTTSReqInfo holds the per-request identifier, the text to
	// synthesize and the operation name.
	VolcengineTTSReqInfo struct {
		ReqID     string `json:"reqid"`
		Text      string `json:"text"`
		Operation string `json:"operation"`
	}
)
|
||||||
|
|
||||||
|
// Response payload types for the Volcengine TTS HTTP API.
type (
	// VolcengineTTSResponse is the JSON envelope returned by the TTS endpoint.
	// Data is decoded as base64 by handleTTSResponse, which also treats any
	// Code other than 3000 as a failure described by Message.
	VolcengineTTSResponse struct {
		ReqID    string                     `json:"reqid"`
		Code     int                        `json:"code"`
		Message  string                     `json:"message"`
		Sequence int                        `json:"sequence"`
		Data     string                     `json:"data"`
		Addition *VolcengineTTSAdditionInfo `json:"addition,omitempty"`
	}

	// VolcengineTTSAdditionInfo is the optional "addition" section of a
	// response. Duration is transported as a string.
	VolcengineTTSAdditionInfo struct {
		Duration string `json:"duration"`
	}
)
|
||||||
|
|
||||||
|
// openAIToVolcengineVoiceMap translates OpenAI voice names into Volcengine
// voice_type identifiers. Names missing from this table are handled by
// mapVoiceType.
var openAIToVolcengineVoiceMap = map[string]string{
	// Voices mapped to zh_male_* identifiers.
	"alloy": "zh_male_M392_conversation_wvae_bigtts",
	"echo":  "zh_male_wenhao_mars_bigtts",
	"onyx":  "zh_male_zhibei_mars_bigtts",

	// Voices mapped to zh_female_* identifiers.
	"fable":   "zh_female_tianmei_mars_bigtts",
	"nova":    "zh_female_shuangkuaisisi_mars_bigtts",
	"shimmer": "zh_female_cancan_mars_bigtts",
}
|
||||||
|
|
||||||
|
// responseFormatToEncodingMap translates an OpenAI response_format value into
// the Volcengine "encoding" value placed in the request.
var responseFormatToEncodingMap = map[string]string{
	// Formats that keep their own encoding.
	"mp3":  "mp3",
	"wav":  "wav",
	"pcm":  "pcm",
	"opus": "ogg_opus",

	// Formats that are served as mp3 instead.
	// NOTE(review): presumably these have no upstream equivalent — confirm.
	"aac":  "mp3",
	"flac": "mp3",
}
|
||||||
|
|
||||||
|
// parseVolcengineAuth splits a channel key of the form "appid|access_token"
// into its two components.
//
// The key must contain exactly one "|" separator and both halves must be
// non-empty; otherwise an error describing the expected format is returned and
// both result strings are empty.
func parseVolcengineAuth(apiKey string) (appID, token string, err error) {
	parts := strings.Split(apiKey, "|")
	// The message names the real separator ("|"); the previous text advertised
	// ":" which the parser never accepted.
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", errors.New("invalid api key format, expected: appid|access_token")
	}
	return parts[0], parts[1], nil
}
|
||||||
|
|
||||||
|
func mapVoiceType(openAIVoice string) string {
|
||||||
|
if voice, ok := openAIToVolcengineVoiceMap[openAIVoice]; ok {
|
||||||
|
return voice
|
||||||
|
}
|
||||||
|
return "zh_male_M392_conversation_wvae_bigtts"
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapSpeedRatio converts the requested speed into the speed_ratio sent
// upstream. A speed of exactly 0 is treated as "not specified" and becomes 1.0;
// every other value is clamped into the inclusive range [0.1, 2.0].
func mapSpeedRatio(speed float64) float64 {
	const (
		defaultRatio = 1.0
		minRatio     = 0.1
		maxRatio     = 2.0
	)

	switch {
	case speed == 0:
		return defaultRatio
	case speed < minRatio:
		return minRatio
	case speed > maxRatio:
		return maxRatio
	default:
		return speed
	}
}
|
||||||
|
|
||||||
|
func mapEncoding(responseFormat string) string {
|
||||||
|
if responseFormat == "" {
|
||||||
|
return "mp3"
|
||||||
|
}
|
||||||
|
if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
|
||||||
|
return encoding
|
||||||
|
}
|
||||||
|
return "mp3"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getContentTypeByEncoding returns the HTTP Content-Type for a Volcengine
// audio encoding name. Encodings without a dedicated entry are reported as
// "application/octet-stream".
func getContentTypeByEncoding(encoding string) string {
	switch encoding {
	case "mp3":
		return "audio/mpeg"
	case "ogg_opus":
		return "audio/ogg"
	case "wav":
		return "audio/wav"
	case "pcm":
		return "audio/pcm"
	default:
		return "application/octet-stream"
	}
}
|
||||||
|
|
||||||
|
func handleTTSResponse(c *gin.Context, resp *http.Response, encoding string) (usage any, err *types.NewAPIError) {
|
||||||
|
body, readErr := io.ReadAll(resp.Body)
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, types.NewErrorWithStatusCode(
|
||||||
|
errors.New("failed to read volcengine response"),
|
||||||
|
types.ErrorCodeReadResponseBodyFailed,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var volcResp VolcengineTTSResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &volcResp); unmarshalErr != nil {
|
||||||
|
return nil, types.NewErrorWithStatusCode(
|
||||||
|
errors.New("failed to parse volcengine response"),
|
||||||
|
types.ErrorCodeBadResponseBody,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if volcResp.Code != 3000 {
|
||||||
|
return nil, types.NewErrorWithStatusCode(
|
||||||
|
errors.New(volcResp.Message),
|
||||||
|
types.ErrorCodeBadResponse,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
audioData, decodeErr := base64.StdEncoding.DecodeString(volcResp.Data)
|
||||||
|
if decodeErr != nil {
|
||||||
|
return nil, types.NewErrorWithStatusCode(
|
||||||
|
errors.New("failed to decode audio data"),
|
||||||
|
types.ErrorCodeBadResponseBody,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := getContentTypeByEncoding(encoding)
|
||||||
|
c.Header("Content-Type", contentType)
|
||||||
|
c.Data(http.StatusOK, contentType, audioData)
|
||||||
|
|
||||||
|
info := c.MustGet("relay_info").(*relaycommon.RelayInfo)
|
||||||
|
usage = &dto.Usage{
|
||||||
|
PromptTokens: info.PromptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: info.PromptTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRequestID() string {
|
||||||
|
return uuid.New().String()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user