feat: add doubao tts usage token

This commit is contained in:
feitianbubu
2025-10-17 21:33:08 +08:00
parent 1be4e12ca0
commit 0952631fa2
2 changed files with 3 additions and 6 deletions

View File

@@ -319,7 +319,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeAudioSpeech { if info.RelayMode == constant.RelayModeAudioSpeech {
encoding := mapEncoding(c.GetString("response_format")) encoding := mapEncoding(c.GetString("response_format"))
return handleTTSResponse(c, resp, encoding) return handleTTSResponse(c, resp, info, encoding)
} }
adaptor := openai.Adaptor{} adaptor := openai.Adaptor{}

View File

@@ -119,6 +119,7 @@ func mapVoiceType(openAIVoice string) string {
return openAIVoice return openAIVoice
} }
// [0.1,2],默认为 1通常保留一位小数即可
func mapSpeedRatio(speed float64) float64 { func mapSpeedRatio(speed float64) float64 {
if speed == 0 { if speed == 0 {
return 1.0 return 1.0
@@ -133,9 +134,6 @@ func mapSpeedRatio(speed float64) float64 {
} }
func mapEncoding(responseFormat string) string { func mapEncoding(responseFormat string) string {
if responseFormat == "" {
return "mp3"
}
if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok { if encoding, ok := responseFormatToEncodingMap[responseFormat]; ok {
return encoding return encoding
} }
@@ -155,7 +153,7 @@ func getContentTypeByEncoding(encoding string) string {
return "application/octet-stream" return "application/octet-stream"
} }
func handleTTSResponse(c *gin.Context, resp *http.Response, encoding string) (usage any, err *types.NewAPIError) { func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, encoding string) (usage any, err *types.NewAPIError) {
body, readErr := io.ReadAll(resp.Body) body, readErr := io.ReadAll(resp.Body)
if readErr != nil { if readErr != nil {
return nil, types.NewErrorWithStatusCode( return nil, types.NewErrorWithStatusCode(
@@ -196,7 +194,6 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, encoding string) (us
c.Header("Content-Type", contentType) c.Header("Content-Type", contentType)
c.Data(http.StatusOK, contentType, audioData) c.Data(http.StatusOK, contentType, audioData)
info := c.MustGet("relay_info").(*relaycommon.RelayInfo)
usage = &dto.Usage{ usage = &dto.Usage{
PromptTokens: info.PromptTokens, PromptTokens: info.PromptTokens,
CompletionTokens: 0, CompletionTokens: 0,