fix: incorrect whisper audio usage
This commit is contained in:
@@ -24,7 +24,7 @@ FROM alpine
|
|||||||
|
|
||||||
RUN apk update \
|
RUN apk update \
|
||||||
&& apk upgrade \
|
&& apk upgrade \
|
||||||
&& apk add --no-cache ca-certificates tzdata \
|
&& apk add --no-cache ca-certificates tzdata ffmpeg\
|
||||||
&& update-ca-certificates 2>/dev/null || true
|
&& update-ca-certificates 2>/dev/null || true
|
||||||
|
|
||||||
COPY --from=builder2 /build/one-api /
|
COPY --from=builder2 /build/one-api /
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
crand "crypto/rand"
|
crand "crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -207,3 +212,31 @@ func RandomSleep() {
|
|||||||
// Sleep for 0-3000 ms
|
// Sleep for 0-3000 ms
|
||||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||||
|
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||||
|
f, err := os.CreateTemp(os.TempDir(), filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(f, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
|
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||||
|
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||||
|
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||||
|
output, err := c.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -13,6 +16,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -316,6 +320,11 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
// count tokens by audio file duration
|
||||||
|
audioTokens, err := countAudioTokens(c)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -340,70 +349,52 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
var text string
|
|
||||||
switch responseFormat {
|
|
||||||
case "json":
|
|
||||||
text, err = getTextFromJSON(responseBody)
|
|
||||||
case "text":
|
|
||||||
text, err = getTextFromText(responseBody)
|
|
||||||
case "srt":
|
|
||||||
text, err = getTextFromSRT(responseBody)
|
|
||||||
case "verbose_json":
|
|
||||||
text, err = getTextFromVerboseJSON(responseBody)
|
|
||||||
case "vtt":
|
|
||||||
text, err = getTextFromVTT(responseBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage.PromptTokens = audioTokens
|
||||||
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
usage.CompletionTokens = 0
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTextFromVTT(body []byte) (string, error) {
|
func countAudioTokens(c *gin.Context) (int, error) {
|
||||||
return getTextFromSRT(body)
|
body, err := common.GetRequestBody(c)
|
||||||
}
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
|
||||||
var whisperResponse dto.WhisperVerboseJSONResponse
|
|
||||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
||||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
||||||
}
|
}
|
||||||
return whisperResponse.Text, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromSRT(body []byte) (string, error) {
|
var reqBody struct {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||||
var builder strings.Builder
|
|
||||||
var textLine bool
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if textLine {
|
|
||||||
builder.WriteString(line)
|
|
||||||
textLine = false
|
|
||||||
continue
|
|
||||||
} else if strings.Contains(line, "-->") {
|
|
||||||
textLine = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
return "", err
|
if err = c.ShouldBind(&reqBody); err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
return builder.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextFromText(body []byte) (string, error) {
|
reqFp, err := reqBody.File.Open()
|
||||||
return strings.TrimSuffix(string(body), "\n"), nil
|
if err != nil {
|
||||||
}
|
return 0, errors.WithStack(err)
|
||||||
|
|
||||||
func getTextFromJSON(body []byte) (string, error) {
|
|
||||||
var whisperResponse dto.AudioResponse
|
|
||||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
|
||||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
|
||||||
}
|
}
|
||||||
return whisperResponse.Text, nil
|
|
||||||
|
tmpFp, err := os.CreateTemp("", "audio-*")
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tmpFp.Name())
|
||||||
|
|
||||||
|
_, err = io.Copy(tmpFp, reqFp)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
if err = tmpFp.Close(); err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||||
|
|||||||
Reference in New Issue
Block a user