Changes: - Replaced error returns with logging for response body copy failures to prevent early termination of the request. - Ensured that the response body is closed properly after writing to the client. - Added comments to clarify the handling of billing and error reporting after the response has been sent. This update improves error handling and maintains resource management in the OpenAI handler.
651 lines
22 KiB
Go
651 lines
22 KiB
Go
package openai
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"mime/multipart"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
)
|
|
|
|
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
|
if data == "" {
|
|
return nil
|
|
}
|
|
|
|
if !forceFormat && !thinkToContent {
|
|
return helper.StringData(c, data)
|
|
}
|
|
|
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !thinkToContent {
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
hasThinkingContent := false
|
|
hasContent := false
|
|
var thinkingContent strings.Builder
|
|
for _, choice := range lastStreamResponse.Choices {
|
|
if len(choice.Delta.GetReasoningContent()) > 0 {
|
|
hasThinkingContent = true
|
|
thinkingContent.WriteString(choice.Delta.GetReasoningContent())
|
|
}
|
|
if len(choice.Delta.GetContentString()) > 0 {
|
|
hasContent = true
|
|
}
|
|
}
|
|
|
|
// Handle think to content conversion
|
|
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
|
if hasThinkingContent {
|
|
response := lastStreamResponse.Copy()
|
|
for i := range response.Choices {
|
|
// send `think` tag with thinking content
|
|
response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
|
|
response.Choices[i].Delta.ReasoningContent = nil
|
|
response.Choices[i].Delta.Reasoning = nil
|
|
}
|
|
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
|
info.ThinkingContentInfo.HasSentThinkingContent = true
|
|
return helper.ObjectData(c, response)
|
|
}
|
|
}
|
|
|
|
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
// Process each choice
|
|
for i, choice := range lastStreamResponse.Choices {
|
|
// Handle transition from thinking to content
|
|
// only send `</think>` tag when previous thinking content has been sent
|
|
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
|
|
response := lastStreamResponse.Copy()
|
|
for j := range response.Choices {
|
|
response.Choices[j].Delta.SetContentString("\n</think>\n")
|
|
response.Choices[j].Delta.ReasoningContent = nil
|
|
response.Choices[j].Delta.Reasoning = nil
|
|
}
|
|
info.ThinkingContentInfo.SendLastThinkingContent = true
|
|
helper.ObjectData(c, response)
|
|
}
|
|
|
|
// Convert reasoning content to regular content if any
|
|
if len(choice.Delta.GetReasoningContent()) > 0 {
|
|
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
|
|
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
|
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
|
} else if !hasThinkingContent && !hasContent {
|
|
// flush thinking content
|
|
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
|
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
|
}
|
|
}
|
|
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
if resp == nil || resp.Body == nil {
|
|
common.LogError(c, "invalid response or response body")
|
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
containStreamUsage := false
|
|
var responseId string
|
|
var createAt int64 = 0
|
|
var systemFingerprint string
|
|
model := info.UpstreamModelName
|
|
|
|
var responseTextBuilder strings.Builder
|
|
var toolCount int
|
|
var usage = &dto.Usage{}
|
|
var streamItems []string // store stream items
|
|
var forceFormat bool
|
|
var thinkToContent bool
|
|
|
|
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
|
forceFormat = forceFmt
|
|
}
|
|
|
|
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
|
|
thinkToContent = think2Content
|
|
}
|
|
|
|
var (
|
|
lastStreamData string
|
|
)
|
|
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
if lastStreamData != "" {
|
|
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
if err != nil {
|
|
common.SysError("error handling stream format: " + err.Error())
|
|
}
|
|
}
|
|
lastStreamData = data
|
|
streamItems = append(streamItems, data)
|
|
return true
|
|
})
|
|
|
|
shouldSendLastResp := true
|
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
|
|
if err == nil {
|
|
responseId = lastStreamResponse.Id
|
|
createAt = lastStreamResponse.Created
|
|
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
|
model = lastStreamResponse.Model
|
|
if service.ValidUsage(lastStreamResponse.Usage) {
|
|
containStreamUsage = true
|
|
usage = lastStreamResponse.Usage
|
|
if !info.ShouldIncludeUsage {
|
|
shouldSendLastResp = false
|
|
}
|
|
}
|
|
for _, choice := range lastStreamResponse.Choices {
|
|
if choice.FinishReason != nil {
|
|
shouldSendLastResp = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if shouldSendLastResp {
|
|
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
}
|
|
|
|
// 处理token计算
|
|
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
|
common.SysError("error processing tokens: " + err.Error())
|
|
}
|
|
|
|
if !containStreamUsage {
|
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
|
usage.CompletionTokens += toolCount * 7
|
|
} else {
|
|
if info.ChannelType == common.ChannelTypeDeepSeek {
|
|
if usage.PromptCacheHitTokens != 0 {
|
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
}
|
|
}
|
|
}
|
|
|
|
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
|
|
|
return nil, usage
|
|
}
|
|
|
|
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
var simpleResponse dto.OpenAITextResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = common.DecodeJson(responseBody, &simpleResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
Error: *simpleResponse.Error,
|
|
StatusCode: resp.StatusCode,
|
|
}, nil
|
|
}
|
|
|
|
forceFormat := false
|
|
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
|
forceFormat = forceFmt
|
|
}
|
|
|
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
|
completionTokens := 0
|
|
for _, choice := range simpleResponse.Choices {
|
|
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
|
completionTokens += ctkm
|
|
}
|
|
simpleResponse.Usage = dto.Usage{
|
|
PromptTokens: info.PromptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: info.PromptTokens + completionTokens,
|
|
}
|
|
}
|
|
|
|
switch info.RelayFormat {
|
|
case relaycommon.RelayFormatOpenAI:
|
|
if forceFormat {
|
|
responseBody, err = json.Marshal(simpleResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
case relaycommon.RelayFormatClaude:
|
|
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
|
claudeRespStr, err := json.Marshal(claudeResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
responseBody = claudeRespStr
|
|
}
|
|
|
|
// Reset response body
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
// So the httpClient will be confused by the response.
|
|
// For example, Postman will report error, and we cannot check the response at all.
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
common.SysError("error copying response body: " + err.Error())
|
|
}
|
|
resp.Body.Close()
|
|
return nil, &simpleResponse.Usage
|
|
}
|
|
|
|
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
// the status code has been judged before, if there is a body reading failure,
|
|
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
|
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
|
// if the upstream returns a specific status code, once the upstream has already written the header,
|
|
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
|
// and can be terminated directly.
|
|
defer resp.Body.Close()
|
|
usage := &dto.Usage{}
|
|
usage.PromptTokens = info.PromptTokens
|
|
usage.TotalTokens = info.PromptTokens
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
c.Writer.WriteHeaderNow()
|
|
_, err := io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
common.LogError(c, err.Error())
|
|
}
|
|
return nil, usage
|
|
}
|
|
|
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
// count tokens by audio file duration
|
|
audioTokens, err := countAudioTokens(c)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
|
}
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
// Reset response body
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
// So the httpClient will be confused by the response.
|
|
// For example, Postman will report error, and we cannot check the response at all.
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
resp.Body.Close()
|
|
|
|
usage := &dto.Usage{}
|
|
usage.PromptTokens = audioTokens
|
|
usage.CompletionTokens = 0
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
return nil, usage
|
|
}
|
|
|
|
func countAudioTokens(c *gin.Context) (int, error) {
|
|
body, err := common.GetRequestBody(c)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
var reqBody struct {
|
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
|
}
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
|
if err = c.ShouldBind(&reqBody); err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
|
reqFp, err := reqBody.File.Open()
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
defer reqFp.Close()
|
|
|
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
defer os.Remove(tmpFp.Name())
|
|
|
|
_, err = io.Copy(tmpFp, reqFp)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
if err = tmpFp.Close(); err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
|
}
|
|
|
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
|
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
|
|
}
|
|
|
|
info.IsStream = true
|
|
clientConn := info.ClientWs
|
|
targetConn := info.TargetWs
|
|
|
|
clientClosed := make(chan struct{})
|
|
targetClosed := make(chan struct{})
|
|
sendChan := make(chan []byte, 100)
|
|
receiveChan := make(chan []byte, 100)
|
|
errChan := make(chan error, 2)
|
|
|
|
usage := &dto.RealtimeUsage{}
|
|
localUsage := &dto.RealtimeUsage{}
|
|
sumUsage := &dto.RealtimeUsage{}
|
|
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-c.Done():
|
|
return
|
|
default:
|
|
_, message, err := clientConn.ReadMessage()
|
|
if err != nil {
|
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
errChan <- fmt.Errorf("error reading from client: %v", err)
|
|
}
|
|
close(clientClosed)
|
|
return
|
|
}
|
|
|
|
realtimeEvent := &dto.RealtimeEvent{}
|
|
err = json.Unmarshal(message, realtimeEvent)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
return
|
|
}
|
|
|
|
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
|
if realtimeEvent.Session != nil {
|
|
if realtimeEvent.Session.Tools != nil {
|
|
info.RealtimeTools = realtimeEvent.Session.Tools
|
|
}
|
|
}
|
|
}
|
|
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
localUsage.InputTokens += textToken + audioToken
|
|
localUsage.InputTokenDetails.TextTokens += textToken
|
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
|
|
err = helper.WssString(c, targetConn, string(message))
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error writing to target: %v", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case sendChan <- message:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-c.Done():
|
|
return
|
|
default:
|
|
_, message, err := targetConn.ReadMessage()
|
|
if err != nil {
|
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
errChan <- fmt.Errorf("error reading from target: %v", err)
|
|
}
|
|
close(targetClosed)
|
|
return
|
|
}
|
|
info.SetFirstResponseTime()
|
|
realtimeEvent := &dto.RealtimeEvent{}
|
|
err = json.Unmarshal(message, realtimeEvent)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
return
|
|
}
|
|
|
|
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
|
realtimeUsage := realtimeEvent.Response.Usage
|
|
if realtimeUsage != nil {
|
|
usage.TotalTokens += realtimeUsage.TotalTokens
|
|
usage.InputTokens += realtimeUsage.InputTokens
|
|
usage.OutputTokens += realtimeUsage.OutputTokens
|
|
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
|
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
|
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
|
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
|
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
|
err := preConsumeUsage(c, info, usage, sumUsage)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
return
|
|
}
|
|
// 本次计费完成,清除
|
|
usage = &dto.RealtimeUsage{}
|
|
|
|
localUsage = &dto.RealtimeUsage{}
|
|
} else {
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
info.IsFirstRequest = false
|
|
localUsage.InputTokens += textToken + audioToken
|
|
localUsage.InputTokenDetails.TextTokens += textToken
|
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
return
|
|
}
|
|
// 本次计费完成,清除
|
|
localUsage = &dto.RealtimeUsage{}
|
|
// print now usage
|
|
}
|
|
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
|
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
|
|
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
|
realtimeSession := realtimeEvent.Session
|
|
if realtimeSession != nil {
|
|
// update audio format
|
|
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
|
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
|
}
|
|
} else {
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
localUsage.OutputTokens += textToken + audioToken
|
|
localUsage.OutputTokenDetails.TextTokens += textToken
|
|
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
|
}
|
|
|
|
err = helper.WssString(c, clientConn, string(message))
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error writing to client: %v", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case receiveChan <- message:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
select {
|
|
case <-clientClosed:
|
|
case <-targetClosed:
|
|
case err := <-errChan:
|
|
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
|
common.LogError(c, "realtime error: "+err.Error())
|
|
case <-c.Done():
|
|
}
|
|
|
|
if usage.TotalTokens != 0 {
|
|
_ = preConsumeUsage(c, info, usage, sumUsage)
|
|
}
|
|
|
|
if localUsage.TotalTokens != 0 {
|
|
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
}
|
|
|
|
// check usage total tokens, if 0, use local usage
|
|
|
|
return nil, sumUsage
|
|
}
|
|
|
|
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
|
if usage == nil || totalUsage == nil {
|
|
return fmt.Errorf("invalid usage pointer")
|
|
}
|
|
|
|
totalUsage.TotalTokens += usage.TotalTokens
|
|
totalUsage.InputTokens += usage.InputTokens
|
|
totalUsage.OutputTokens += usage.OutputTokens
|
|
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
|
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
|
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
|
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
|
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
|
// clear usage
|
|
err := service.PreWssConsumeQuota(ctx, info, usage)
|
|
return err
|
|
}
|
|
|
|
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
// Reset response body
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
// So the httpClient will be confused by the response.
|
|
// For example, Postman will report error, and we cannot check the response at all.
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
// reset content length
|
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
common.SysError("error copying response body: " + err.Error())
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
// Once we've written to the client, we should not return errors anymore
|
|
// because the upstream has already consumed resources and returned content
|
|
// We should still perform billing even if parsing fails
|
|
var usageResp dto.SimpleResponse
|
|
err = json.Unmarshal(responseBody, &usageResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
// format
|
|
if usageResp.InputTokens > 0 {
|
|
usageResp.PromptTokens += usageResp.InputTokens
|
|
}
|
|
if usageResp.OutputTokens > 0 {
|
|
usageResp.CompletionTokens += usageResp.OutputTokens
|
|
}
|
|
if usageResp.InputTokensDetails != nil {
|
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
|
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
|
}
|
|
return nil, &usageResp.Usage
|
|
}
|