package openai

import (
	"bytes"
	"fmt"
	"io"
	"math"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/logger"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/types"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
)

// sendStreamData writes one upstream stream chunk to the client.
//
// Empty data is ignored. When neither forceFormat nor thinkToContent is set,
// the raw chunk is forwarded unchanged. Otherwise the chunk is decoded into a
// dto.ChatCompletionsStreamResponse and re-encoded; with thinkToContent set,
// reasoning content carried in the deltas is moved into the regular content
// field, and info.ThinkingContentInfo tracks which transition chunks have
// already been emitted for this stream.
//
// It returns the first decode or write error encountered.
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
	if data == "" {
		return nil
	}

	// Fast path: nothing to rewrite, pass the chunk through as-is.
	if !forceFormat && !thinkToContent {
		return helper.StringData(c, data)
	}

	var lastStreamResponse dto.ChatCompletionsStreamResponse
	if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
		return err
	}

	if !thinkToContent {
		return helper.ObjectData(c, lastStreamResponse)
	}

	// Scan all choices once to learn what this chunk carries.
	hasThinkingContent := false
	hasContent := false
	var thinkingContent strings.Builder
	for _, choice := range lastStreamResponse.Choices {
		if len(choice.Delta.GetReasoningContent()) > 0 {
			hasThinkingContent = true
			thinkingContent.WriteString(choice.Delta.GetReasoningContent())
		}
		if len(choice.Delta.GetContentString()) > 0 {
			hasContent = true
		}
	}

	// NOTE(review): the comments in this function refer to a `think` tag, but
	// the string literals below contain only newlines. Confirm whether the
	// opening/closing tags are meant to be part of these literals.

	// Handle think to content conversion: the first chunk that carries
	// reasoning content is sent with the accumulated thinking text as content.
	if info.ThinkingContentInfo.IsFirstThinkingContent && hasThinkingContent {
		response := lastStreamResponse.Copy()
		for i := range response.Choices {
			// send `think` tag with thinking content
			response.Choices[i].Delta.SetContentString("\n" + thinkingContent.String())
			response.Choices[i].Delta.ReasoningContent = nil
			response.Choices[i].Delta.Reasoning = nil
		}
		info.ThinkingContentInfo.IsFirstThinkingContent = false
		info.ThinkingContentInfo.HasSentThinkingContent = true
		return helper.ObjectData(c, response)
	}

	if len(lastStreamResponse.Choices) == 0 {
		return helper.ObjectData(c, lastStreamResponse)
	}

	// Process each choice
	for i, choice := range lastStreamResponse.Choices {
		// Handle transition from thinking to content.
		// Only send the closing tag chunk when previous thinking content has
		// been sent; SendLastThinkingContent ensures it is emitted once.
		if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
			response := lastStreamResponse.Copy()
			for j := range response.Choices {
				response.Choices[j].Delta.SetContentString("\n\n")
				response.Choices[j].Delta.ReasoningContent = nil
				response.Choices[j].Delta.Reasoning = nil
			}
			info.ThinkingContentInfo.SendLastThinkingContent = true
			// Surface a failed write instead of silently dropping it.
			if err := helper.ObjectData(c, response); err != nil {
				return err
			}
		}

		// Convert reasoning content to regular content if any
		if len(choice.Delta.GetReasoningContent()) > 0 {
			lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		} else if !hasThinkingContent && !hasContent {
			// flush thinking content
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		}
	}

	return helper.ObjectData(c, lastStreamResponse)
}

// OaiStreamHandler relays a streaming chat-completions response to the client
// and returns the usage to bill.
//
// Each chunk is forwarded one step behind the scanner: when a new chunk
// arrives, the previous one is sent through HandleStreamFormat. The final
// chunk is therefore held back until the stream ends, so it can be inspected
// by handleLastResponse before being sent. If the upstream did not report
// usage in the stream, usage is estimated from the collected response text,
// and each tool call is counted as 7 extra completion tokens.
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		logger.LogError(c, "invalid response or response body")
		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}

	defer service.CloseResponseBodyGracefully(resp)

	model := info.UpstreamModelName
	var responseId string
	var createAt int64
	var systemFingerprint string
	var containStreamUsage bool
	var responseTextBuilder strings.Builder
	var toolCount int
	var usage = &dto.Usage{}
	var streamItems []string // store stream items
	var lastStreamData string

	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		// Forward the previously buffered chunk before buffering the new one.
		if lastStreamData != "" {
			err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
			if err != nil {
				common.SysLog("error handling stream format: " + err.Error())
			}
		}
		if len(data) > 0 {
			lastStreamData = data
			streamItems = append(streamItems, data)
		}
		return true
	})

	// Handle the final response chunk.
	shouldSendLastResp := true
	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
		&containStreamUsage, info, &shouldSendLastResp); err != nil {
		logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
	}

	if info.RelayFormat == types.RelayFormatOpenAI && shouldSendLastResp {
		// Best effort: the stream is ending, so a write failure here is not reported.
		_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
	}

	// Compute token usage from the collected stream items.
	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
		logger.LogError(c, "error processing tokens: "+err.Error())
	}

	if !containStreamUsage {
		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		usage.CompletionTokens += toolCount * 7
	} else if info.ChannelType == constant.ChannelTypeDeepSeek {
		// DeepSeek channels: copy the cache-hit count into the cached-token detail field.
		if usage.PromptCacheHitTokens != 0 {
			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
		}
	}

	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)

	return usage, nil
}

// OpenaiHandler relays a non-streaming text response to the client and
// returns the usage to bill.
//
// If the upstream body carries an OpenAI error with a non-empty type, that
// error is returned with the upstream status code. If the upstream reports
// zero prompt tokens, usage is rebuilt from info.PromptTokens and either the
// reported completion tokens or a local count over each choice's content and
// reasoning text. The body written to the client depends on info.RelayFormat:
// OpenAI (optionally re-marshalled when ForceFormat is set), Claude or Gemini.
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	var simpleResponse dto.OpenAITextResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	if common.DebugEnabled {
		println("upstream response body:", string(responseBody))
	}
	err = common.Unmarshal(responseBody, &simpleResponse)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
	}

	forceFormat := info.ChannelSetting.ForceFormat

	// Rebuild usage locally when the upstream did not report prompt tokens.
	usageModified := false
	if simpleResponse.Usage.PromptTokens == 0 {
		completionTokens := simpleResponse.Usage.CompletionTokens
		if completionTokens == 0 {
			for _, choice := range simpleResponse.Choices {
				ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
				completionTokens += ctkm
			}
		}
		simpleResponse.Usage = dto.Usage{
			PromptTokens:     info.PromptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      info.PromptTokens + completionTokens,
		}
		usageModified = true
	}

	switch info.RelayFormat {
	case types.RelayFormatOpenAI:
		if usageModified {
			// Patch only the "usage" key so all other upstream fields are kept.
			var bodyMap map[string]interface{}
			err = common.Unmarshal(responseBody, &bodyMap)
			if err != nil {
				return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
			}
			bodyMap["usage"] = simpleResponse.Usage
			responseBody, err = common.Marshal(bodyMap)
			if err != nil {
				// Previously ignored; a failure here would have sent an empty body.
				return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
			}
		}
		if forceFormat {
			responseBody, err = common.Marshal(simpleResponse)
			if err != nil {
				return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
			}
		}
	case types.RelayFormatClaude:
		claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
		claudeRespStr, err := common.Marshal(claudeResp)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
		responseBody = claudeRespStr
	case types.RelayFormatGemini:
		geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
		geminiRespStr, err := common.Marshal(geminiResp)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
		responseBody = geminiRespStr
	}

	service.IOCopyBytesGracefully(c, resp, responseBody)

	return &simpleResponse.Usage, nil
}

// OpenaiTTSHandler streams a text-to-speech response body to the client and
// returns usage with prompt and total tokens set to info.PromptTokens.
//
// The first value of every upstream response header is copied to the client,
// followed by the upstream status code and the body.
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
	// the status code has been judged before, if there is a body reading failure,
	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
	// if the upstream returns a specific status code, once the upstream has already written the header,
	// the subsequent failure of the response body should be regarded as a non-recoverable error,
	// and can be terminated directly.
	defer service.CloseResponseBodyGracefully(resp)
	usage := &dto.Usage{}
	usage.PromptTokens = info.PromptTokens
	usage.TotalTokens = info.PromptTokens
	for k, v := range resp.Header {
		// Guard against an empty value slice so v[0] cannot panic.
		if len(v) == 0 {
			continue
		}
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	c.Writer.WriteHeaderNow()
	_, err := io.Copy(c.Writer, resp.Body)
	if err != nil {
		logger.LogError(c, err.Error())
	}
	return usage
}

// OpenaiSTTHandler relays a speech-to-text response to the client and returns
// usage derived from the duration of the uploaded audio file.
//
// Note the return order is (error, usage). The responseFormat parameter is not
// used by this function.
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
	defer service.CloseResponseBodyGracefully(resp)

	// count tokens by audio file duration
	audioTokens, err := countAudioTokens(c)
	if err != nil {
		return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
	}
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
	}
	// Write the response body to the client.
	service.IOCopyBytesGracefully(c, resp, responseBody)

	usage := &dto.Usage{}
	usage.PromptTokens = audioTokens
	usage.CompletionTokens = 0
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return nil, usage
}

// countAudioTokens computes the token count for the audio file uploaded in
// the "file" form field of the current request.
//
// The upload is copied to a temporary file (removed on return), its duration
// is measured with common.GetAudioDuration, and the duration is rounded up to
// whole seconds and converted at 1000 tokens per minute.
func countAudioTokens(c *gin.Context) (int, error) {
	body, err := common.GetRequestBody(c)
	if err != nil {
		return 0, errors.WithStack(err)
	}

	var reqBody struct {
		File *multipart.FileHeader `form:"file" binding:"required"`
	}
	// Restore the request body from the bytes read above so it can be bound.
	c.Request.Body = io.NopCloser(bytes.NewReader(body))
	if err = c.ShouldBind(&reqBody); err != nil {
		return 0, errors.WithStack(err)
	}
	ext := filepath.Ext(reqBody.File.Filename) // file extension
	reqFp, err := reqBody.File.Open()
	if err != nil {
		return 0, errors.WithStack(err)
	}
	defer reqFp.Close()

	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
	if err != nil {
		return 0, errors.WithStack(err)
	}
	defer os.Remove(tmpFp.Name())

	_, err = io.Copy(tmpFp, reqFp)
	if err != nil {
		// Release the file handle on the failure path too; the copy error is
		// the one worth reporting, so the close error is intentionally dropped.
		_ = tmpFp.Close()
		return 0, errors.WithStack(err)
	}
	// Close before measuring so the data is fully written out.
	if err = tmpFp.Close(); err != nil {
		return 0, errors.WithStack(err)
	}

	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
	if err != nil {
		return 0, errors.WithStack(err)
	}

	return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute is billed as 1k tokens
}

// OpenaiRealtimeHandler proxies realtime websocket traffic between
// info.ClientWs and info.TargetWs and returns the accumulated usage.
//
// Two goroutines are started: one reads client messages and forwards them to
// the target, the other reads target messages and forwards them to the
// client. Token usage is accumulated into sumUsage through preConsumeUsage.
// The function returns as soon as either side closes, an error is reported,
// or c.Done() fires; errors from the goroutines are logged, not returned.
//
// Note the return order is (error, usage).
//
// NOTE(review): usage and localUsage are read and reassigned by both
// goroutines and by this function without synchronization — confirm whether
// that is acceptable here.
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
	if info == nil || info.ClientWs == nil || info.TargetWs == nil {
		return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
	}

	info.IsStream = true
	clientConn := info.ClientWs
	targetConn := info.TargetWs

	clientClosed := make(chan struct{})
	targetClosed := make(chan struct{})
	// sendChan and receiveChan are only written to in this function; sends are
	// non-blocking, so messages are dropped once the buffers are full.
	sendChan := make(chan []byte, 100)
	receiveChan := make(chan []byte, 100)
	errChan := make(chan error, 2)

	usage := &dto.RealtimeUsage{}      // usage reported by the upstream
	localUsage := &dto.RealtimeUsage{} // usage counted locally
	sumUsage := &dto.RealtimeUsage{}   // total passed to preConsumeUsage so far

	// Client -> target.
	gopool.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				errChan <- fmt.Errorf("panic in client reader: %v", r)
			}
		}()
		for {
			select {
			case <-c.Done():
				return
			default:
				_, message, err := clientConn.ReadMessage()
				if err != nil {
					// Normal and going-away closures are not reported as errors.
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
						errChan <- fmt.Errorf("error reading from client: %v", err)
					}
					close(clientClosed)
					return
				}

				realtimeEvent := &dto.RealtimeEvent{}
				err = common.Unmarshal(message, realtimeEvent)
				if err != nil {
					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
					return
				}

				// Remember the tools declared by a session update.
				if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate &&
					realtimeEvent.Session != nil &&
					realtimeEvent.Session.Tools != nil {
					info.RealtimeTools = realtimeEvent.Session.Tools
				}

				textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
				if err != nil {
					errChan <- fmt.Errorf("error counting text token: %v", err)
					return
				}
				logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
				localUsage.TotalTokens += textToken + audioToken
				localUsage.InputTokens += textToken + audioToken
				localUsage.InputTokenDetails.TextTokens += textToken
				localUsage.InputTokenDetails.AudioTokens += audioToken

				err = helper.WssString(c, targetConn, string(message))
				if err != nil {
					errChan <- fmt.Errorf("error writing to target: %v", err)
					return
				}

				select {
				case sendChan <- message:
				default:
				}
			}
		}
	})

	// Target -> client.
	gopool.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				errChan <- fmt.Errorf("panic in target reader: %v", r)
			}
		}()
		for {
			select {
			case <-c.Done():
				return
			default:
				_, message, err := targetConn.ReadMessage()
				if err != nil {
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
						errChan <- fmt.Errorf("error reading from target: %v", err)
					}
					close(targetClosed)
					return
				}
				info.SetFirstResponseTime()
				realtimeEvent := &dto.RealtimeEvent{}
				err = common.Unmarshal(message, realtimeEvent)
				if err != nil {
					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
					return
				}

				if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
					realtimeUsage := realtimeEvent.Response.Usage
					if realtimeUsage != nil {
						// The upstream reported usage for this response: bill it.
						usage.TotalTokens += realtimeUsage.TotalTokens
						usage.InputTokens += realtimeUsage.InputTokens
						usage.OutputTokens += realtimeUsage.OutputTokens
						usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
						usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
						err := preConsumeUsage(c, info, usage, sumUsage)
						if err != nil {
							errChan <- fmt.Errorf("error consume usage: %v", err)
							return
						}
						// Billing for this round is done; reset both counters.
						usage = &dto.RealtimeUsage{}
						localUsage = &dto.RealtimeUsage{}
					} else {
						// No upstream usage: count this event locally and bill localUsage.
						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
						if err != nil {
							errChan <- fmt.Errorf("error counting text token: %v", err)
							return
						}
						logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
						localUsage.TotalTokens += textToken + audioToken
						info.IsFirstRequest = false
						localUsage.InputTokens += textToken + audioToken
						localUsage.InputTokenDetails.TextTokens += textToken
						localUsage.InputTokenDetails.AudioTokens += audioToken
						err = preConsumeUsage(c, info, localUsage, sumUsage)
						if err != nil {
							errChan <- fmt.Errorf("error consume usage: %v", err)
							return
						}
						// Billing for this round is done; reset the local counter.
						localUsage = &dto.RealtimeUsage{}
					}
					// Print the current usage (the original logged localUsage twice).
					logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
					logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
					realtimeSession := realtimeEvent.Session
					if realtimeSession != nil {
						// update audio format
						info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
						info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
					}
				} else {
					// Any other target event is counted locally as output.
					textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
					if err != nil {
						errChan <- fmt.Errorf("error counting text token: %v", err)
						return
					}
					logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
					localUsage.TotalTokens += textToken + audioToken
					localUsage.OutputTokens += textToken + audioToken
					localUsage.OutputTokenDetails.TextTokens += textToken
					localUsage.OutputTokenDetails.AudioTokens += audioToken
				}

				err = helper.WssString(c, clientConn, string(message))
				if err != nil {
					errChan <- fmt.Errorf("error writing to client: %v", err)
					return
				}

				select {
				case receiveChan <- message:
				default:
				}
			}
		}
	})

	// Wait for the first terminating condition.
	select {
	case <-clientClosed:
	case <-targetClosed:
	case err := <-errChan:
		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
		logger.LogError(c, "realtime error: "+err.Error())
	case <-c.Done():
	}

	// Bill whatever is still pending; the session is over, so failures here
	// are intentionally not reported.
	if usage.TotalTokens != 0 {
		_ = preConsumeUsage(c, info, usage, sumUsage)
	}

	if localUsage.TotalTokens != 0 {
		_ = preConsumeUsage(c, info, localUsage, sumUsage)
	}

	// check usage total tokens, if 0, use local usage

	return nil, sumUsage
}

// preConsumeUsage adds usage into totalUsage field by field and then passes
// usage to service.PreWssConsumeQuota, returning that call's error.
//
// It returns an error without touching anything when either pointer is nil.
// The usage value itself is not reset here; callers replace it afterwards.
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
	if usage == nil || totalUsage == nil {
		return fmt.Errorf("invalid usage pointer")
	}

	totalUsage.TotalTokens += usage.TotalTokens
	totalUsage.InputTokens += usage.InputTokens
	totalUsage.OutputTokens += usage.OutputTokens
	totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
	totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
	totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
	totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
	totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
	// clear usage
	err := service.PreWssConsumeQuota(ctx, info, usage)
	return err
}

// OpenaiHandlerWithUsage forwards the upstream body to the client unchanged
// and returns the usage parsed from it.
//
// The body is decoded into a dto.SimpleResponse. Input/output token fields
// are folded into the prompt/completion fields so callers see one shape.
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}

	var usageResp dto.SimpleResponse
	err = common.Unmarshal(responseBody, &usageResp)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// Write the response body to the client.
	service.IOCopyBytesGracefully(c, resp, responseBody)

	// Once we've written to the client, we should not return errors anymore
	// because the upstream has already consumed resources and returned content
	// We should still perform billing even if parsing fails
	// format
	if usageResp.InputTokens > 0 {
		usageResp.PromptTokens += usageResp.InputTokens
	}
	if usageResp.OutputTokens > 0 {
		usageResp.CompletionTokens += usageResp.OutputTokens
	}
	if usageResp.InputTokensDetails != nil {
		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
	}
	return &usageResp.Usage, nil
}