feat(audio): enhance audio request handling with token type detection and streaming support
@@ -2,6 +2,7 @@ package dto
 import (
     "encoding/json"
+    "strings"
 
     "github.com/QuantumNous/new-api/types"
 
@@ -24,11 +25,14 @@ func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
         CombineText: r.Input,
         TokenType:   types.TokenTypeTextNumber,
     }
+    if strings.Contains(r.Model, "gpt") {
+        meta.TokenType = types.TokenTypeTokenizer
+    }
     return meta
 }
 
 func (r *AudioRequest) IsStream(c *gin.Context) bool {
-    return false
+    return r.StreamFormat == "sse"
 }
 
 func (r *AudioRequest) SetModelName(modelName string) {
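For reference, a standalone sketch of the two request-side decisions this hunk introduces: gpt-prefixed models are counted with a tokenizer instead of plain text length, and only `stream_format=sse` marks the request as streaming. The type and constants below (`audioRequest`, the token-type strings) are simplified stand-ins for the project's `dto`/`types` definitions, not the actual API.

```go
package main

import (
	"fmt"
	"strings"
)

// Simplified stand-ins for the project's token-type constants.
const (
	tokenTypeTextNumber = "text_number" // count the input text directly
	tokenTypeTokenizer  = "tokenizer"   // run the input through a real tokenizer
)

type audioRequest struct {
	Model        string
	Input        string
	StreamFormat string // "sse" requests a streamed response
}

// tokenType mirrors the new GetTokenCountMeta behaviour: gpt-* audio models
// are counted with the tokenizer, everything else by plain text length.
func (r audioRequest) tokenType() string {
	if strings.Contains(r.Model, "gpt") {
		return tokenTypeTokenizer
	}
	return tokenTypeTextNumber
}

// isStream mirrors the new IsStream behaviour: only stream_format=sse streams.
func (r audioRequest) isStream() bool {
	return r.StreamFormat == "sse"
}

func main() {
	r := audioRequest{Model: "gpt-4o-mini-tts", Input: "hello", StreamFormat: "sse"}
	fmt.Println(r.tokenType(), r.isStream()) // tokenizer true
}
```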
@@ -67,8 +67,11 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
         service.ResetStatusCode(newAPIError, statusCodeMappingStr)
         return newAPIError
     }
-    postConsumeQuota(c, info, usage.(*dto.Usage), "")
+    if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
+        service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
+    } else {
+        postConsumeQuota(c, info, usage.(*dto.Usage), "")
+    }
 
     return nil
 }
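The same usage-based routing reappears in the TextHelper hunk further down. A minimal illustration of the decision, with `usageLike` standing in for the fields of `dto.Usage` that the diff inspects (everything else here is assumed for the sketch):

```go
package main

import "fmt"

// usageLike stands in for the dto.Usage fields read by the new check.
type usageLike struct {
	PromptAudioTokens     int // PromptTokensDetails.AudioTokens
	CompletionAudioTokens int // CompletionTokenDetails.AudioTokens
}

// billsAsAudio mirrors the new routing: any audio tokens on either side of the
// exchange send the usage down the audio pricing path instead of the text one.
func billsAsAudio(u usageLike) bool {
	return u.CompletionAudioTokens > 0 || u.PromptAudioTokens > 0
}

func main() {
	fmt.Println(billsAsAudio(usageLike{CompletionAudioTokens: 1200})) // true  -> PostAudioConsumeQuota
	fmt.Println(billsAsAudio(usageLike{}))                            // false -> postConsumeQuota
}
```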
relay/channel/openai/audio.go (new file, 145 lines)
@@ -0,0 +1,145 @@
+package openai
+
+import (
+    "bytes"
+    "fmt"
+    "io"
+    "math"
+    "net/http"
+
+    "github.com/QuantumNous/new-api/common"
+    "github.com/QuantumNous/new-api/constant"
+    "github.com/QuantumNous/new-api/dto"
+    "github.com/QuantumNous/new-api/logger"
+    relaycommon "github.com/QuantumNous/new-api/relay/common"
+    "github.com/QuantumNous/new-api/relay/helper"
+    "github.com/QuantumNous/new-api/service"
+    "github.com/QuantumNous/new-api/types"
+    "github.com/gin-gonic/gin"
+)
+
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
+    // The status code has been judged before; if there is a body reading failure,
+    // it should be regarded as a non-recoverable error, so it should not return err for external retry.
+    // Analogous to nginx's load balancing, it will only retry if it can't be requested or
+    // if the upstream returns a specific status code; once the upstream has already written the header,
+    // the subsequent failure of the response body should be regarded as a non-recoverable error,
+    // and can be terminated directly.
+    defer service.CloseResponseBodyGracefully(resp)
+    usage := &dto.Usage{}
+    usage.PromptTokens = info.GetEstimatePromptTokens()
+    usage.TotalTokens = info.GetEstimatePromptTokens()
+    for k, v := range resp.Header {
+        c.Writer.Header().Set(k, v[0])
+    }
+    c.Writer.WriteHeader(resp.StatusCode)
+
+    if info.IsStream {
+        helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+            if service.SundaySearch(data, "usage") {
+                var simpleResponse dto.SimpleResponse
+                err := common.Unmarshal([]byte(data), &simpleResponse)
+                if err != nil {
+                    logger.LogError(c, err.Error())
+                }
+                if simpleResponse.Usage.TotalTokens != 0 {
+                    usage.PromptTokens = simpleResponse.Usage.InputTokens
+                    usage.CompletionTokens = simpleResponse.OutputTokens
+                    usage.TotalTokens = simpleResponse.TotalTokens
+                }
+            }
+            _ = helper.StringData(c, data)
+            return true
+        })
+    } else {
+        common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
+        // read the response body into a buffer
+        bodyBytes, err := io.ReadAll(resp.Body)
+        if err != nil {
+            logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
+            c.Writer.WriteHeaderNow()
+            return usage
+        }
+
+        // write the response to the client
+        c.Writer.WriteHeaderNow()
+        _, err = c.Writer.Write(bodyBytes)
+        if err != nil {
+            logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
+        }
+
+        // compute the audio duration and update usage
+        audioFormat := "mp3" // default format
+        if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
+            audioFormat = audioReq.ResponseFormat
+        }
+
+        var duration float64
+        var durationErr error
+
+        if audioFormat == "pcm" {
+            // PCM has no file header; derive the duration from OpenAI TTS's PCM parameters:
+            // sample rate 24000 Hz, bit depth 16-bit (2 bytes), 1 channel
+            const sampleRate = 24000
+            const bytesPerSample = 2
+            const channels = 1
+            duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
+        } else {
+            ext := "." + audioFormat
+            reader := bytes.NewReader(bodyBytes)
+            duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
+        }
+
+        usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+        if durationErr != nil {
+            logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
+            // if the duration cannot be determined, fall back to CompletionTokens estimated from the body size
+            sizeInKB := float64(len(bodyBytes)) / 1000.0
+            estimatedTokens := int(math.Ceil(sizeInKB)) // rough estimate: about 1 token per KB
+            usage.CompletionTokens = estimatedTokens
+            usage.CompletionTokenDetails.AudioTokens = estimatedTokens
+        } else if duration > 0 {
+            // tokens = ceil(duration) / 60.0 * 1000, i.e. 1000 tokens per minute
+            completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
+            usage.CompletionTokens = completionTokens
+            usage.CompletionTokenDetails.AudioTokens = completionTokens
+        }
+    }
+
+    return usage
+}
+
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
+    defer service.CloseResponseBodyGracefully(resp)
+
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
+    }
+    // write back the new response body
+    service.IOCopyBytesGracefully(c, resp, responseBody)
+
+    var responseData struct {
+        Usage *dto.Usage `json:"usage"`
+    }
+    if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
+        if responseData.Usage.TotalTokens > 0 {
+            usage := responseData.Usage
+            if usage.PromptTokens == 0 {
+                usage.PromptTokens = usage.InputTokens
+            }
+            if usage.CompletionTokens == 0 {
+                usage.CompletionTokens = usage.OutputTokens
+            }
+            return nil, usage
+        }
+    }
+
+    usage := &dto.Usage{}
+    usage.PromptTokens = info.GetEstimatePromptTokens()
+    usage.CompletionTokens = 0
+    usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    return nil, usage
+}
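The duration-to-token arithmetic in the new TTS handler is easy to check in isolation. The sketch below reproduces the three estimation paths (headerless PCM, a decoded duration, and the size-based fallback) with plain numbers; it is self-contained and does not call the project's helpers such as common.GetAudioDuration.

```go
package main

import (
	"fmt"
	"math"
)

// pcmDuration derives the play time of headerless PCM as produced by OpenAI TTS:
// 24000 Hz sample rate, 16-bit (2-byte) samples, mono.
func pcmDuration(bodyLen int) float64 {
	const sampleRate, bytesPerSample, channels = 24000, 2, 1
	return float64(bodyLen) / float64(sampleRate*bytesPerSample*channels)
}

// tokensFromDuration mirrors the handler's rule of 1000 tokens per minute,
// rounding the duration up to whole seconds first.
func tokensFromDuration(seconds float64) int {
	return int(math.Round(math.Ceil(seconds) / 60.0 * 1000))
}

// tokensFromSize is the fallback when no duration is available:
// roughly 1 token per KB of response body.
func tokensFromSize(bodyLen int) int {
	return int(math.Ceil(float64(bodyLen) / 1000.0))
}

func main() {
	// 30 s of PCM is 24000 * 2 * 1 * 30 = 1,440,000 bytes.
	d := pcmDuration(1_440_000)
	fmt.Println(d, tokensFromDuration(d)) // 30 500
	fmt.Println(tokensFromSize(480_000))  // 480 (fallback estimate)
}
```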
@@ -1,7 +1,6 @@
 package openai
 
 import (
-    "encoding/json"
     "fmt"
     "io"
     "net/http"
@@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
             var streamResp struct {
                 Usage *dto.Usage `json:"usage"`
             }
-            err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
+            err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
             if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
                 usage = streamResp.Usage
                 containStreamUsage = true
@@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) {
         }
     }
 }
-
-func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
-    // the status code has been judged before, if there is a body reading failure,
-    // it should be regarded as a non-recoverable error, so it should not return err for external retry.
-    // Analogous to nginx's load balancing, it will only retry if it can't be requested or
-    // if the upstream returns a specific status code, once the upstream has already written the header,
-    // the subsequent failure of the response body should be regarded as a non-recoverable error,
-    // and can be terminated directly.
-    defer service.CloseResponseBodyGracefully(resp)
-    usage := &dto.Usage{}
-    usage.PromptTokens = info.GetEstimatePromptTokens()
-    usage.TotalTokens = info.GetEstimatePromptTokens()
-    for k, v := range resp.Header {
-        c.Writer.Header().Set(k, v[0])
-    }
-    c.Writer.WriteHeader(resp.StatusCode)
-
-    isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
-    if isStreaming {
-        streamTTSResponse(c, resp)
-    } else {
-        c.Writer.WriteHeaderNow()
-        _, err := io.Copy(c.Writer, resp.Body)
-        if err != nil {
-            logger.LogError(c, err.Error())
-        }
-    }
-    return usage
-}
-
-func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
-    defer service.CloseResponseBodyGracefully(resp)
-
-    responseBody, err := io.ReadAll(resp.Body)
-    if err != nil {
-        return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
-    }
-    // write back the new response body
-    service.IOCopyBytesGracefully(c, resp, responseBody)
-
-    var responseData struct {
-        Usage *dto.Usage `json:"usage"`
-    }
-    if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
-        if responseData.Usage.TotalTokens > 0 {
-            usage := responseData.Usage
-            if usage.PromptTokens == 0 {
-                usage.PromptTokens = usage.InputTokens
-            }
-            if usage.CompletionTokens == 0 {
-                usage.CompletionTokens = usage.OutputTokens
-            }
-            return nil, usage
-        }
-    }
-
-    usage := &dto.Usage{}
-    usage.PromptTokens = info.GetEstimatePromptTokens()
-    usage.CompletionTokens = 0
-    usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-    return nil, usage
-}
-
 func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
     if info == nil || info.ClientWs == nil || info.TargetWs == nil {
         return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
@@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
         } `json:"usage"`
     }
 
-    if err := json.Unmarshal(body, &payload); err != nil {
+    if err := common.Unmarshal(body, &payload); err != nil {
         return 0, false
     }
 
@@ -181,7 +181,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
         return newApiErr
     }
 
-    if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+    if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
         service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
     } else {
         postConsumeQuota(c, info, usage.(*dto.Usage), "")
@@ -536,7 +536,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
         if name == "gpt-4o-2024-05-13" {
             return 3, true
         }
-        return 4, true
+        return 4, false
     }
     // gpt-5 match
     if strings.HasPrefix(name, "gpt-5") {