feat(audio): enhance audio request handling with token type detection and streaming support

CaIon
2025-12-13 17:24:23 +08:00
parent b602843ce1
commit e36e2e1b69
6 changed files with 159 additions and 70 deletions
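
Two of the hunks below swap encoding/json's Unmarshal for the project's common.Unmarshal helper. The helper's body is not part of this diff; a minimal sketch, assuming it is nothing more than a signature-compatible wrapper around the standard library (the real helper may delegate to a different JSON codec or add error wrapping), looks like this:

package common

import "encoding/json"

// Unmarshal keeps the exact signature of json.Unmarshal so call sites such as
// OaiStreamHandler and extractCachedTokensFromBody can switch to it without any
// other change. The delegation target shown here is an assumption; only the
// signature is implied by the diff.
func Unmarshal(data []byte, v interface{}) error {
	return json.Unmarshal(data, v)
}

Because the signature matches, each replacement in the hunks below is a one-line, drop-in change.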


@@ -1,7 +1,6 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var streamResp struct {
Usage *dto.Usage `json:"usage"`
}
- err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
+ err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
usage = streamResp.Usage
containStreamUsage = true
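
For context, the branch above relies on the second-to-last stream frame carrying a usage object, which OpenAI-compatible upstreams emit just before the final [DONE] marker when stream_options.include_usage is enabled. A self-contained sketch of the same pattern, with a representative payload (not captured from this repository) and a TotalTokens > 0 check standing in for service.ValidUsage:

package main

import (
	"encoding/json"
	"fmt"
)

// Usage is a local stand-in for the fields of dto.Usage read here.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func main() {
	// Typical final data frame before [DONE]: empty choices, populated usage.
	secondLastStreamData := `{"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46}}`

	var streamResp struct {
		Usage *Usage `json:"usage"`
	}
	err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
	if err == nil && streamResp.Usage != nil && streamResp.Usage.TotalTokens > 0 {
		fmt.Println("billable total tokens:", streamResp.Usage.TotalTokens) // 46
	}
}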
@@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) {
}
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// The status code has already been checked before this point, so a failure while
// reading the response body is treated as a non-recoverable error and must not be
// returned as err for an external retry. This mirrors nginx load balancing: a request
// is retried only when the upstream cannot be reached or returns a specific status
// code; once the upstream has written the response header, any later failure in the
// response body is non-recoverable and the request is simply terminated.
defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.TotalTokens = info.GetEstimatePromptTokens()
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
if isStreaming {
streamTTSResponse(c, resp)
} else {
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
}
}
return usage
}
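
The handler above picks between streamed and buffered passthrough purely from the upstream's Content-Length: a missing header (ContentLength == -1) is treated as a chunked, streaming response and forwarded chunk by chunk, anything else is copied in one go. A minimal standalone sketch of that branch against a plain http.ResponseWriter (gin's c.Writer satisfies both interfaces used here):

package relay

import (
	"io"
	"net/http"
)

// copyUpstream forwards an upstream audio response to the client, flushing as
// data arrives whenever the upstream did not announce a Content-Length.
func copyUpstream(w http.ResponseWriter, resp *http.Response) error {
	isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
	w.WriteHeader(resp.StatusCode)
	if !isStreaming {
		// Fixed-size body: a single buffered copy is enough.
		_, err := io.Copy(w, resp.Body)
		return err
	}
	flusher, _ := w.(http.Flusher)
	buf := make([]byte, 32*1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				return writeErr
			}
			if flusher != nil {
				flusher.Flush() // push each chunk to the client immediately
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}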
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// write the new response body
service.IOCopyBytesGracefully(c, resp, responseBody)
var responseData struct {
Usage *dto.Usage `json:"usage"`
}
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
if responseData.Usage.TotalTokens > 0 {
usage := responseData.Usage
if usage.PromptTokens == 0 {
usage.PromptTokens = usage.InputTokens
}
if usage.CompletionTokens == 0 {
usage.CompletionTokens = usage.OutputTokens
}
return nil, usage
}
}
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
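
The fallback in OpenaiSTTHandler above is the token type detection named in the commit title: some upstream transcription responses report input_tokens/output_tokens rather than prompt_tokens/completion_tokens, so the handler copies them over before the usage is billed. A self-contained sketch of that normalization, with a local Usage struct standing in for dto.Usage:

package main

import "fmt"

// Usage is a local stand-in for the dto.Usage fields involved in the fallback.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
	InputTokens      int `json:"input_tokens"`
	OutputTokens     int `json:"output_tokens"`
}

// normalizeUsage mirrors the fallback in OpenaiSTTHandler: prefer the
// prompt/completion fields, and fall back to input/output when they are zero.
func normalizeUsage(u *Usage) {
	if u.PromptTokens == 0 {
		u.PromptTokens = u.InputTokens
	}
	if u.CompletionTokens == 0 {
		u.CompletionTokens = u.OutputTokens
	}
}

func main() {
	u := &Usage{InputTokens: 21, OutputTokens: 148, TotalTokens: 169}
	normalizeUsage(u)
	fmt.Println(u.PromptTokens, u.CompletionTokens, u.TotalTokens) // 21 148 169
}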
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
@@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
} `json:"usage"`
}
- if err := json.Unmarshal(body, &payload); err != nil {
+ if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
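
The truncated payload struct above reads cached token counts out of the response body; in the public OpenAI schema these are reported under usage.prompt_tokens_details.cached_tokens. A hedged, standalone sketch of that extraction (the field path follows the public schema, and extractCachedTokens below is a hypothetical stand-in, since the repository's full struct and return convention are not visible in this hunk):

package main

import (
	"encoding/json"
	"fmt"
)

// extractCachedTokens is a hypothetical stand-in for extractCachedTokensFromBody.
func extractCachedTokens(body []byte) (int, bool) {
	var payload struct {
		Usage struct {
			PromptTokensDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"prompt_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &payload); err != nil {
		return 0, false
	}
	return payload.Usage.PromptTokensDetails.CachedTokens, true
}

func main() {
	body := []byte(`{"usage":{"prompt_tokens_details":{"cached_tokens":128}}}`)
	if cached, ok := extractCachedTokens(body); ok {
		fmt.Println("cached prompt tokens:", cached) // 128
	}
}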