fix: incorrect whisper audio usage
This commit is contained in:
@@ -5,7 +5,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,6 +16,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
var text string
|
||||
switch responseFormat {
|
||||
case "json":
|
||||
text, err = getTextFromJSON(responseBody)
|
||||
case "text":
|
||||
text, err = getTextFromText(responseBody)
|
||||
case "srt":
|
||||
text, err = getTextFromSRT(responseBody)
|
||||
case "verbose_json":
|
||||
text, err = getTextFromVerboseJSON(responseBody)
|
||||
case "vtt":
|
||||
text, err = getTextFromVTT(responseBody)
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
||||
usage.PromptTokens = audioTokens
|
||||
usage.CompletionTokens = 0
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func getTextFromVTT(body []byte) (string, error) {
|
||||
return getTextFromSRT(body)
|
||||
}
|
||||
|
||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
||||
var whisperResponse dto.WhisperVerboseJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
func countAudioTokens(c *gin.Context) (int, error) {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func getTextFromSRT(body []byte) (string, error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
||||
var builder strings.Builder
|
||||
var textLine bool
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if textLine {
|
||||
builder.WriteString(line)
|
||||
textLine = false
|
||||
continue
|
||||
} else if strings.Contains(line, "-->") {
|
||||
textLine = true
|
||||
continue
|
||||
}
|
||||
var reqBody struct {
|
||||
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func getTextFromText(body []byte) (string, error) {
|
||||
return strings.TrimSuffix(string(body), "\n"), nil
|
||||
}
|
||||
|
||||
func getTextFromJSON(body []byte) (string, error) {
|
||||
var whisperResponse dto.AudioResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*")
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer os.Remove(tmpFp.Name())
|
||||
|
||||
_, err = io.Copy(tmpFp, reqFp)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
if err = tmpFp.Close(); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||
|
||||
Reference in New Issue
Block a user