From 2abf05b31409fbeeebcda8d33362577565e4827e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B2=88=E6=B5=A9?= Date: Fri, 17 Jan 2025 18:12:05 +0800 Subject: [PATCH] fix: incorrect whisper audio usage --- Dockerfile | 2 +- common/utils.go | 33 ++++++++++ relay/channel/openai/relay-openai.go | 97 +++++++++++++--------------- 3 files changed, 78 insertions(+), 54 deletions(-) diff --git a/Dockerfile b/Dockerfile index eebf982a..44a7837a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,7 +24,7 @@ FROM alpine RUN apk update \ && apk upgrade \ - && apk add --no-cache ca-certificates tzdata \ + && apk add --no-cache ca-certificates tzdata ffmpeg\ && update-ca-certificates 2>/dev/null || true COPY --from=builder2 /build/one-api / diff --git a/common/utils.go b/common/utils.go index 26c8236b..fb769a7c 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,14 +1,19 @@ package common import ( + "bytes" + "context" crand "crypto/rand" "encoding/base64" "fmt" + "github.com/pkg/errors" "html/template" + "io" "log" "math/big" "math/rand" "net" + "os" "os/exec" "runtime" "strconv" @@ -207,3 +212,31 @@ func RandomSleep() { // Sleep for 0-3000 ms time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) } + +// SaveTmpFile saves data to a temporary file. A random string is appended to the filename. +func SaveTmpFile(filename string, data io.Reader) (string, error) { + f, err := os.CreateTemp(os.TempDir(), filename) + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary file %s", filename) + } + defer f.Close() + + _, err = io.Copy(f, data) + if err != nil { + return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename) + } + + return f.Name(), nil +} + +// GetAudioDuration returns the duration of an audio file in seconds. 
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration") + } + + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d8b1aef3..537ccb32 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,7 +5,10 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/pkg/errors" "io" + "math" + "mime/multipart" "net/http" "one-api/common" "one-api/constant" @@ -13,6 +16,7 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" + "os" "strings" "sync" "time" @@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // count tokens by audio file duration + audioTokens, err := countAudioTokens(c) + if err != nil { + return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil + } responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } resp.Body.Close() - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = 
getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - } - usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName) + usage.PromptTokens = audioTokens + usage.CompletionTokens = 0 usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage } -func getTextFromVTT(body []byte) (string, error) { - return getTextFromSRT(body) -} - -func getTextFromVerboseJSON(body []byte) (string, error) { - var whisperResponse dto.WhisperVerboseJSONResponse - if err := json.Unmarshal(body, &whisperResponse); err != nil { - return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) +func countAudioTokens(c *gin.Context) (int, error) { + body, err := common.GetRequestBody(c) + if err != nil { + return 0, errors.WithStack(err) } - return whisperResponse.Text, nil -} -func getTextFromSRT(body []byte) (string, error) { - scanner := bufio.NewScanner(strings.NewReader(string(body))) - var builder strings.Builder - var textLine bool - for scanner.Scan() { - line := scanner.Text() - if textLine { - builder.WriteString(line) - textLine = false - continue - } else if strings.Contains(line, "-->") { - textLine = true - continue - } + var reqBody struct { + File *multipart.FileHeader `form:"file" binding:"required"` } - if err := scanner.Err(); err != nil { - return "", err + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + if err = c.ShouldBind(&reqBody); err != nil { + return 0, errors.WithStack(err) } - return builder.String(), nil -} -func getTextFromText(body []byte) (string, error) { - return strings.TrimSuffix(string(body), "\n"), nil -} - -func getTextFromJSON(body []byte) (string, error) { - var whisperResponse dto.AudioResponse - if err := json.Unmarshal(body, &whisperResponse); err != nil { - return "", fmt.Errorf("unmarshal_response_body_failed err :%w", 
err) + reqFp, err := reqBody.File.Open() + if err != nil { + return 0, errors.WithStack(err) } - return whisperResponse.Text, nil + + tmpFp, err := os.CreateTemp("", "audio-*") + if err != nil { + return 0, errors.WithStack(err) + } + defer os.Remove(tmpFp.Name()) + + _, err = io.Copy(tmpFp, reqFp) + if err != nil { + return 0, errors.WithStack(err) + } + if err = tmpFp.Close(); err != nil { + return 0, errors.WithStack(err) + } + + duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name()) + if err != nil { + return 0, errors.WithStack(err) + } + + return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute is equivalent to 1k tokens } func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {