261 lines
8.6 KiB
Go
261 lines
8.6 KiB
Go
package openai
|
||
|
||
import (
|
||
"encoding/json"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
"one-api/logger"
|
||
relaycommon "one-api/relay/common"
|
||
relayconstant "one-api/relay/constant"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/types"
|
||
"strings"
|
||
|
||
"github.com/samber/lo"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Helper functions
|
||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||
info.SendResponseCount++
|
||
|
||
switch info.RelayFormat {
|
||
case types.RelayFormatOpenAI:
|
||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||
case types.RelayFormatClaude:
|
||
return handleClaudeFormat(c, data, info)
|
||
case types.RelayFormatGemini:
|
||
return handleGeminiFormat(c, data, info)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||
return err
|
||
}
|
||
|
||
if streamResponse.Usage != nil {
|
||
info.ClaudeConvertInfo.Usage = streamResponse.Usage
|
||
}
|
||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||
for _, resp := range claudeResponses {
|
||
helper.ClaudeData(c, *resp)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||
return err
|
||
}
|
||
|
||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||
|
||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||
if geminiResponse == nil {
|
||
return nil
|
||
}
|
||
|
||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||
if err != nil {
|
||
logger.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||
return err
|
||
}
|
||
|
||
// send gemini format response
|
||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||
_ = helper.FlushWriter(c)
|
||
return nil
|
||
}
|
||
|
||
// ProcessStreamResponse appends the text carried by each choice of a chat
// completion stream chunk to responseTextBuilder. For every choice it appends:
//   - the delta content,
//   - the reasoning content,
//   - the name and arguments of every tool call.
//
// *toolCount is raised to the largest number of tool calls found in a single
// choice delta; it is never lowered. The returned error is currently always nil.
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
	for i := range streamResponse.Choices {
		delta := streamResponse.Choices[i].Delta
		responseTextBuilder.WriteString(delta.GetContentString())
		responseTextBuilder.WriteString(delta.GetReasoningContent())

		calls := delta.ToolCalls
		if calls == nil {
			continue
		}
		if n := len(calls); n > *toolCount {
			*toolCount = n
		}
		for _, call := range calls {
			responseTextBuilder.WriteString(call.Function.Name)
			responseTextBuilder.WriteString(call.Function.Arguments)
		}
	}
	return nil
}
|
||
|
||
// processTokens accumulates the textual content of buffered stream chunks into
// responseTextBuilder. Only the chat-completions and completions relay modes
// are handled; any other relayMode is a no-op returning nil.
//
// streamItems holds the raw JSON of each chunk. The items are joined into a
// single JSON array so that the decoders can try one batch decode first.
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
	joined := "[" + strings.Join(streamItems, ",") + "]"

	if relayMode == relayconstant.RelayModeChatCompletions {
		return processChatCompletions(joined, streamItems, responseTextBuilder, toolCount)
	}
	if relayMode == relayconstant.RelayModeCompletions {
		return processCompletions(joined, streamItems, responseTextBuilder)
	}
	return nil
}
|
||
|
||
// processChatCompletions feeds every chat-completion stream chunk through
// ProcessStreamResponse. That appends content, reasoning and tool-call text to
// responseTextBuilder and tracks the maximum tool-call count in toolCount.
//
// streamResp is streamItems joined into one JSON array. It is decoded in a
// single pass when possible. If that fails (for example because one chunk is
// malformed), each item is decoded individually instead. In that fallback,
// undecodable items are logged and skipped so that the remaining chunks are
// still counted. The first per-item decode error is returned after all items
// have been processed; otherwise the result is nil.
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
	var streamResponses []dto.ChatCompletionsStreamResponse
	if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
		// The batch decode failed: fall back to decoding items one by one.
		common.SysLog("error unmarshalling stream response: " + err.Error())

		var firstErr error
		for _, item := range streamItems {
			var streamResponse dto.ChatCompletionsStreamResponse
			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
				// Skip the bad chunk instead of aborting. Returning here would
				// drop every chunk after it from the accumulated text.
				common.SysLog("error unmarshalling stream item: " + err.Error())
				if firstErr == nil {
					firstErr = err
				}
				continue
			}
			if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
				common.SysLog("error processing stream response: " + err.Error())
			}
		}
		return firstErr
	}

	// Batch path: every chunk decoded successfully.
	for _, streamResponse := range streamResponses {
		if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
			common.SysLog("error processing stream response: " + err.Error())
		}
	}
	return nil
}
|
||
|
||
// processCompletions appends the Text of every choice in every legacy
// completions stream chunk to responseTextBuilder.
//
// streamResp is streamItems joined into one JSON array, decoded in a single
// pass when possible. If that fails, each item is decoded on its own and items
// that cannot be decoded are skipped. The returned error is always nil.
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
	appendText := func(resp dto.CompletionsStreamResponse) {
		for _, choice := range resp.Choices {
			responseTextBuilder.WriteString(choice.Text)
		}
	}

	var batch []dto.CompletionsStreamResponse
	batchErr := json.Unmarshal(common.StringToByteSlice(streamResp), &batch)
	if batchErr == nil {
		for _, resp := range batch {
			appendText(resp)
		}
		return nil
	}

	// The batch decode failed: decode one item at a time, skipping bad ones.
	common.SysLog("error unmarshalling stream response: " + batchErr.Error())
	for _, item := range streamItems {
		var single dto.CompletionsStreamResponse
		if json.Unmarshal(common.StringToByteSlice(item), &single) != nil {
			continue
		}
		appendText(single)
	}
	return nil
}
|
||
|
||
// handleLastResponse decodes the final chat-completion stream chunk. It copies
// the chunk's id, creation time, system fingerprint and model into the
// caller-supplied variables.
//
// When service.ValidUsage accepts the chunk's usage, *containStreamUsage is set
// to true and *usage points at that usage. If, in addition,
// info.ShouldIncludeUsage is false, *shouldSendLastResp is set to whether any
// choice still carries content or reasoning text. The caller presumably uses
// this flag to decide whether to forward the chunk (confirm at call site).
//
// A decode error is returned before any output variable is written.
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
	systemFingerprint *string, model *string, usage **dto.Usage,
	containStreamUsage *bool, info *relaycommon.RelayInfo,
	shouldSendLastResp *bool) error {

	var last dto.ChatCompletionsStreamResponse
	if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &last); err != nil {
		return err
	}

	*responseId, *createAt = last.Id, last.Created
	*systemFingerprint, *model = last.GetSystemFingerprint(), last.Model

	if !service.ValidUsage(last.Usage) {
		return nil
	}
	*containStreamUsage = true
	*usage = last.Usage

	if info.ShouldIncludeUsage {
		return nil
	}
	hasText := func(choice dto.ChatCompletionsStreamResponseChoice) bool {
		return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
	}
	*shouldSendLastResp = lo.SomeBy(last.Choices, hasText)
	return nil
}
|
||
|
||
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||
responseId string, createAt int64, model string, systemFingerprint string,
|
||
usage *dto.Usage, containStreamUsage bool) {
|
||
|
||
switch info.RelayFormat {
|
||
case types.RelayFormatOpenAI:
|
||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||
response.SetSystemFingerprint(systemFingerprint)
|
||
helper.ObjectData(c, response)
|
||
}
|
||
helper.Done(c)
|
||
|
||
case types.RelayFormatClaude:
|
||
info.ClaudeConvertInfo.Done = true
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||
return
|
||
}
|
||
|
||
info.ClaudeConvertInfo.Usage = usage
|
||
|
||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||
for _, resp := range claudeResponses {
|
||
_ = helper.ClaudeData(c, *resp)
|
||
}
|
||
|
||
case types.RelayFormatGemini:
|
||
var streamResponse dto.ChatCompletionsStreamResponse
|
||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||
return
|
||
}
|
||
|
||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||
// 暂不知是否有程序会不兼容。
|
||
|
||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||
|
||
// openai 流响应开头的空数据
|
||
if geminiResponse == nil {
|
||
return
|
||
}
|
||
|
||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||
if err != nil {
|
||
common.SysLog("error marshalling gemini response: " + err.Error())
|
||
return
|
||
}
|
||
|
||
// 发送最终的 Gemini 响应
|
||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||
_ = helper.FlushWriter(c)
|
||
}
|
||
}
|
||
|
||
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
|
||
if data == "" {
|
||
return
|
||
}
|
||
helper.ResponseChunkData(c, streamResponse, data)
|
||
}
|