197 lines
6.5 KiB
Go
197 lines
6.5 KiB
Go
package openai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
relaycommon "one-api/relay/common"
|
|
relayconstant "one-api/relay/constant"
|
|
"one-api/relay/helper"
|
|
"one-api/service"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// 辅助函数
|
|
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
|
info.SendResponseCount++
|
|
switch info.RelayFormat {
|
|
case relaycommon.RelayFormatOpenAI:
|
|
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
|
case relaycommon.RelayFormatClaude:
|
|
return handleClaudeFormat(c, data, info)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
|
var streamResponse dto.ChatCompletionsStreamResponse
|
|
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
|
return err
|
|
}
|
|
|
|
if streamResponse.Usage != nil {
|
|
info.ClaudeConvertInfo.Usage = streamResponse.Usage
|
|
}
|
|
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
|
for _, resp := range claudeResponses {
|
|
helper.ClaudeData(c, *resp)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
|
for _, choice := range streamResponse.Choices {
|
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
|
if choice.Delta.ToolCalls != nil {
|
|
if len(choice.Delta.ToolCalls) > *toolCount {
|
|
*toolCount = len(choice.Delta.ToolCalls)
|
|
}
|
|
for _, tool := range choice.Delta.ToolCalls {
|
|
responseTextBuilder.WriteString(tool.Function.Name)
|
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
|
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
|
|
switch relayMode {
|
|
case relayconstant.RelayModeChatCompletions:
|
|
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
|
case relayconstant.RelayModeCompletions:
|
|
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
|
var streamResponses []dto.ChatCompletionsStreamResponse
|
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
|
// 一次性解析失败,逐个解析
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
for _, item := range streamItems {
|
|
var streamResponse dto.ChatCompletionsStreamResponse
|
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
|
return err
|
|
}
|
|
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
|
common.SysError("error processing stream response: " + err.Error())
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// 批量处理所有响应
|
|
for _, streamResponse := range streamResponses {
|
|
for _, choice := range streamResponse.Choices {
|
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
|
if choice.Delta.ToolCalls != nil {
|
|
if len(choice.Delta.ToolCalls) > *toolCount {
|
|
*toolCount = len(choice.Delta.ToolCalls)
|
|
}
|
|
for _, tool := range choice.Delta.ToolCalls {
|
|
responseTextBuilder.WriteString(tool.Function.Name)
|
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
|
var streamResponses []dto.CompletionsStreamResponse
|
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
|
// 一次性解析失败,逐个解析
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
for _, item := range streamItems {
|
|
var streamResponse dto.CompletionsStreamResponse
|
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
|
continue
|
|
}
|
|
for _, choice := range streamResponse.Choices {
|
|
responseTextBuilder.WriteString(choice.Text)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// 批量处理所有响应
|
|
for _, streamResponse := range streamResponses {
|
|
for _, choice := range streamResponse.Choices {
|
|
responseTextBuilder.WriteString(choice.Text)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
|
systemFingerprint *string, model *string, usage **dto.Usage,
|
|
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
|
shouldSendLastResp *bool) error {
|
|
|
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
|
return err
|
|
}
|
|
|
|
*responseId = lastStreamResponse.Id
|
|
*createAt = lastStreamResponse.Created
|
|
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
|
*model = lastStreamResponse.Model
|
|
|
|
if service.ValidUsage(lastStreamResponse.Usage) {
|
|
*containStreamUsage = true
|
|
*usage = lastStreamResponse.Usage
|
|
if !info.ShouldIncludeUsage {
|
|
*shouldSendLastResp = false
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
|
responseId string, createAt int64, model string, systemFingerprint string,
|
|
usage *dto.Usage, containStreamUsage bool) {
|
|
|
|
switch info.RelayFormat {
|
|
case relaycommon.RelayFormatOpenAI:
|
|
if info.ShouldIncludeUsage && !containStreamUsage {
|
|
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
|
response.SetSystemFingerprint(systemFingerprint)
|
|
helper.ObjectData(c, response)
|
|
}
|
|
helper.Done(c)
|
|
|
|
case relaycommon.RelayFormatClaude:
|
|
info.ClaudeConvertInfo.Done = true
|
|
var streamResponse dto.ChatCompletionsStreamResponse
|
|
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
return
|
|
}
|
|
|
|
info.ClaudeConvertInfo.Usage = usage
|
|
|
|
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
|
for _, resp := range claudeResponses {
|
|
_ = helper.ClaudeData(c, *resp)
|
|
}
|
|
}
|
|
}
|
|
|
|
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
|
|
if data == "" {
|
|
return
|
|
}
|
|
helper.ResponseChunkData(c, streamResponse, data)
|
|
}
|