This update replaces calls to `DecodeJson` and `DecodeJsonStr` with `UnmarshalJson` and `UnmarshalJsonStr` across the relay handlers, so JSON decoding uses one consistently named API. The change improves maintainability and aligns these handlers with the ongoing refactoring of the codebase.
589 lines
20 KiB
Go
589 lines
20 KiB
Go
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
|
if data == "" {
|
|
return nil
|
|
}
|
|
|
|
if !forceFormat && !thinkToContent {
|
|
return helper.StringData(c, data)
|
|
}
|
|
|
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !thinkToContent {
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
hasThinkingContent := false
|
|
hasContent := false
|
|
var thinkingContent strings.Builder
|
|
for _, choice := range lastStreamResponse.Choices {
|
|
if len(choice.Delta.GetReasoningContent()) > 0 {
|
|
hasThinkingContent = true
|
|
thinkingContent.WriteString(choice.Delta.GetReasoningContent())
|
|
}
|
|
if len(choice.Delta.GetContentString()) > 0 {
|
|
hasContent = true
|
|
}
|
|
}
|
|
|
|
// Handle think to content conversion
|
|
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
|
if hasThinkingContent {
|
|
response := lastStreamResponse.Copy()
|
|
for i := range response.Choices {
|
|
// send `think` tag with thinking content
|
|
response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
|
|
response.Choices[i].Delta.ReasoningContent = nil
|
|
response.Choices[i].Delta.Reasoning = nil
|
|
}
|
|
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
|
info.ThinkingContentInfo.HasSentThinkingContent = true
|
|
return helper.ObjectData(c, response)
|
|
}
|
|
}
|
|
|
|
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
// Process each choice
|
|
for i, choice := range lastStreamResponse.Choices {
|
|
// Handle transition from thinking to content
|
|
// only send `</think>` tag when previous thinking content has been sent
|
|
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
|
|
response := lastStreamResponse.Copy()
|
|
for j := range response.Choices {
|
|
response.Choices[j].Delta.SetContentString("\n</think>\n")
|
|
response.Choices[j].Delta.ReasoningContent = nil
|
|
response.Choices[j].Delta.Reasoning = nil
|
|
}
|
|
info.ThinkingContentInfo.SendLastThinkingContent = true
|
|
helper.ObjectData(c, response)
|
|
}
|
|
|
|
// Convert reasoning content to regular content if any
|
|
if len(choice.Delta.GetReasoningContent()) > 0 {
|
|
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
|
|
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
|
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
|
} else if !hasThinkingContent && !hasContent {
|
|
// flush thinking content
|
|
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
|
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
|
}
|
|
}
|
|
|
|
return helper.ObjectData(c, lastStreamResponse)
|
|
}
|
|
|
|
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
if resp == nil || resp.Body == nil {
|
|
common.LogError(c, "invalid response or response body")
|
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
defer common.CloseResponseBodyGracefully(resp)
|
|
|
|
model := info.UpstreamModelName
|
|
var responseId string
|
|
var createAt int64 = 0
|
|
var systemFingerprint string
|
|
var containStreamUsage bool
|
|
var responseTextBuilder strings.Builder
|
|
var toolCount int
|
|
var usage = &dto.Usage{}
|
|
var streamItems []string // store stream items
|
|
var forceFormat bool
|
|
var thinkToContent bool
|
|
|
|
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
|
forceFormat = forceFmt
|
|
}
|
|
|
|
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
|
|
thinkToContent = think2Content
|
|
}
|
|
|
|
var (
|
|
lastStreamData string
|
|
)
|
|
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
if lastStreamData != "" {
|
|
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
if err != nil {
|
|
common.SysError("error handling stream format: " + err.Error())
|
|
}
|
|
}
|
|
lastStreamData = data
|
|
streamItems = append(streamItems, data)
|
|
return true
|
|
})
|
|
|
|
// 处理最后的响应
|
|
shouldSendLastResp := true
|
|
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
|
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
|
common.SysError("error handling last response: " + err.Error())
|
|
}
|
|
|
|
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
|
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
}
|
|
|
|
// 处理token计算
|
|
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
|
common.SysError("error processing tokens: " + err.Error())
|
|
}
|
|
|
|
if !containStreamUsage {
|
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
|
usage.CompletionTokens += toolCount * 7
|
|
} else {
|
|
if info.ChannelType == common.ChannelTypeDeepSeek {
|
|
if usage.PromptCacheHitTokens != 0 {
|
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
}
|
|
}
|
|
}
|
|
|
|
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
|
|
|
return nil, usage
|
|
}
|
|
|
|
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
defer common.CloseResponseBodyGracefully(resp)
|
|
|
|
var simpleResponse dto.OpenAITextResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = common.UnmarshalJson(responseBody, &simpleResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
Error: *simpleResponse.Error,
|
|
StatusCode: resp.StatusCode,
|
|
}, nil
|
|
}
|
|
|
|
forceFormat := false
|
|
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
|
forceFormat = forceFmt
|
|
}
|
|
|
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
|
completionTokens := 0
|
|
for _, choice := range simpleResponse.Choices {
|
|
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
|
completionTokens += ctkm
|
|
}
|
|
simpleResponse.Usage = dto.Usage{
|
|
PromptTokens: info.PromptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: info.PromptTokens + completionTokens,
|
|
}
|
|
}
|
|
|
|
switch info.RelayFormat {
|
|
case relaycommon.RelayFormatOpenAI:
|
|
if forceFormat {
|
|
responseBody, err = common.EncodeJson(simpleResponse)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
case relaycommon.RelayFormatClaude:
|
|
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
|
claudeRespStr, err := common.EncodeJson(claudeResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
responseBody = claudeRespStr
|
|
}
|
|
|
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
|
|
|
return nil, &simpleResponse.Usage
|
|
}
|
|
|
|
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
// the status code has been judged before, if there is a body reading failure,
|
|
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
|
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
|
// if the upstream returns a specific status code, once the upstream has already written the header,
|
|
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
|
// and can be terminated directly.
|
|
defer common.CloseResponseBodyGracefully(resp)
|
|
usage := &dto.Usage{}
|
|
usage.PromptTokens = info.PromptTokens
|
|
usage.TotalTokens = info.PromptTokens
|
|
for k, v := range resp.Header {
|
|
c.Writer.Header().Set(k, v[0])
|
|
}
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
c.Writer.WriteHeaderNow()
|
|
_, err := io.Copy(c.Writer, resp.Body)
|
|
if err != nil {
|
|
common.LogError(c, err.Error())
|
|
}
|
|
return nil, usage
|
|
}
|
|
|
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
defer common.CloseResponseBodyGracefully(resp)
|
|
|
|
// count tokens by audio file duration
|
|
audioTokens, err := countAudioTokens(c)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
|
}
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
// 写入新的 response body
|
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
|
|
|
usage := &dto.Usage{}
|
|
usage.PromptTokens = audioTokens
|
|
usage.CompletionTokens = 0
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
return nil, usage
|
|
}
|
|
|
|
func countAudioTokens(c *gin.Context) (int, error) {
|
|
body, err := common.GetRequestBody(c)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
var reqBody struct {
|
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
|
}
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
|
if err = c.ShouldBind(&reqBody); err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
|
reqFp, err := reqBody.File.Open()
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
defer reqFp.Close()
|
|
|
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
defer os.Remove(tmpFp.Name())
|
|
|
|
_, err = io.Copy(tmpFp, reqFp)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
if err = tmpFp.Close(); err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
|
if err != nil {
|
|
return 0, errors.WithStack(err)
|
|
}
|
|
|
|
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
|
}
|
|
|
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
|
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
|
|
}
|
|
|
|
info.IsStream = true
|
|
clientConn := info.ClientWs
|
|
targetConn := info.TargetWs
|
|
|
|
clientClosed := make(chan struct{})
|
|
targetClosed := make(chan struct{})
|
|
sendChan := make(chan []byte, 100)
|
|
receiveChan := make(chan []byte, 100)
|
|
errChan := make(chan error, 2)
|
|
|
|
usage := &dto.RealtimeUsage{}
|
|
localUsage := &dto.RealtimeUsage{}
|
|
sumUsage := &dto.RealtimeUsage{}
|
|
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errChan <- fmt.Errorf("panic in client reader: %v", r)
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-c.Done():
|
|
return
|
|
default:
|
|
_, message, err := clientConn.ReadMessage()
|
|
if err != nil {
|
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
errChan <- fmt.Errorf("error reading from client: %v", err)
|
|
}
|
|
close(clientClosed)
|
|
return
|
|
}
|
|
|
|
realtimeEvent := &dto.RealtimeEvent{}
|
|
err = common.UnmarshalJson(message, realtimeEvent)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
return
|
|
}
|
|
|
|
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
|
if realtimeEvent.Session != nil {
|
|
if realtimeEvent.Session.Tools != nil {
|
|
info.RealtimeTools = realtimeEvent.Session.Tools
|
|
}
|
|
}
|
|
}
|
|
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
localUsage.InputTokens += textToken + audioToken
|
|
localUsage.InputTokenDetails.TextTokens += textToken
|
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
|
|
err = helper.WssString(c, targetConn, string(message))
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error writing to target: %v", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case sendChan <- message:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
gopool.Go(func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errChan <- fmt.Errorf("panic in target reader: %v", r)
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-c.Done():
|
|
return
|
|
default:
|
|
_, message, err := targetConn.ReadMessage()
|
|
if err != nil {
|
|
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
errChan <- fmt.Errorf("error reading from target: %v", err)
|
|
}
|
|
close(targetClosed)
|
|
return
|
|
}
|
|
info.SetFirstResponseTime()
|
|
realtimeEvent := &dto.RealtimeEvent{}
|
|
err = common.UnmarshalJson(message, realtimeEvent)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
return
|
|
}
|
|
|
|
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
|
realtimeUsage := realtimeEvent.Response.Usage
|
|
if realtimeUsage != nil {
|
|
usage.TotalTokens += realtimeUsage.TotalTokens
|
|
usage.InputTokens += realtimeUsage.InputTokens
|
|
usage.OutputTokens += realtimeUsage.OutputTokens
|
|
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
|
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
|
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
|
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
|
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
|
err := preConsumeUsage(c, info, usage, sumUsage)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
return
|
|
}
|
|
// 本次计费完成,清除
|
|
usage = &dto.RealtimeUsage{}
|
|
|
|
localUsage = &dto.RealtimeUsage{}
|
|
} else {
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
info.IsFirstRequest = false
|
|
localUsage.InputTokens += textToken + audioToken
|
|
localUsage.InputTokenDetails.TextTokens += textToken
|
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
return
|
|
}
|
|
// 本次计费完成,清除
|
|
localUsage = &dto.RealtimeUsage{}
|
|
// print now usage
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
|
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
|
|
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
|
realtimeSession := realtimeEvent.Session
|
|
if realtimeSession != nil {
|
|
// update audio format
|
|
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
|
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
|
}
|
|
} else {
|
|
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
return
|
|
}
|
|
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
localUsage.TotalTokens += textToken + audioToken
|
|
localUsage.OutputTokens += textToken + audioToken
|
|
localUsage.OutputTokenDetails.TextTokens += textToken
|
|
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
|
}
|
|
|
|
err = helper.WssString(c, clientConn, string(message))
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error writing to client: %v", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case receiveChan <- message:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
select {
|
|
case <-clientClosed:
|
|
case <-targetClosed:
|
|
case err := <-errChan:
|
|
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
|
common.LogError(c, "realtime error: "+err.Error())
|
|
case <-c.Done():
|
|
}
|
|
|
|
if usage.TotalTokens != 0 {
|
|
_ = preConsumeUsage(c, info, usage, sumUsage)
|
|
}
|
|
|
|
if localUsage.TotalTokens != 0 {
|
|
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
}
|
|
|
|
// check usage total tokens, if 0, use local usage
|
|
|
|
return nil, sumUsage
|
|
}
|
|
|
|
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
|
if usage == nil || totalUsage == nil {
|
|
return fmt.Errorf("invalid usage pointer")
|
|
}
|
|
|
|
totalUsage.TotalTokens += usage.TotalTokens
|
|
totalUsage.InputTokens += usage.InputTokens
|
|
totalUsage.OutputTokens += usage.OutputTokens
|
|
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
|
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
|
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
|
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
|
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
|
// clear usage
|
|
err := service.PreWssConsumeQuota(ctx, info, usage)
|
|
return err
|
|
}
|
|
|
|
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
defer common.CloseResponseBodyGracefully(resp)
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
var usageResp dto.SimpleResponse
|
|
err = common.UnmarshalJson(responseBody, &usageResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
// 写入新的 response body
|
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
|
|
|
// Once we've written to the client, we should not return errors anymore
|
|
// because the upstream has already consumed resources and returned content
|
|
// We should still perform billing even if parsing fails
|
|
// format
|
|
if usageResp.InputTokens > 0 {
|
|
usageResp.PromptTokens += usageResp.InputTokens
|
|
}
|
|
if usageResp.OutputTokens > 0 {
|
|
usageResp.CompletionTokens += usageResp.OutputTokens
|
|
}
|
|
if usageResp.InputTokensDetails != nil {
|
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
|
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
|
}
|
|
return nil, &usageResp.Usage
|
|
}
|