new-api/relay/channel/openai/helper.go
package openai
import (
"encoding/json"
"errors"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// Helper functions
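// HandleStreamFormat forwards one upstream stream chunk to the client, converting
// it to Claude or Gemini format when the relay format requires it.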
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
case relaycommon.RelayFormatGemini:
return handleGeminiFormat(c, data, info)
}
return nil
}
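// handleClaudeFormat unmarshals an OpenAI stream chunk, records its usage if present,
// converts it to Claude stream events, and writes them to the client.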
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
_ = helper.ClaudeData(c, *resp)
}
return nil
}
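// handleGeminiFormat unmarshals an OpenAI stream chunk, converts it to a Gemini
// stream response, and writes it to the client as an SSE "data:" event.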
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
// Truncate to the first 50 characters for debug logging
debugData := data
if len(data) > 50 {
debugData = data[:50] + "..."
}
common.LogInfo(c, "handleGeminiFormat called with data: "+debugData)
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err
}
common.LogInfo(c, "successfully unmarshaled stream response")
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// A nil return means there is no actual content, so skip sending
if geminiResponse == nil {
common.LogInfo(c, "handleGeminiFormat: no content to send, skipping")
return nil
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.LogError(c, "failed to marshal gemini response: "+err.Error())
return err
}
common.LogInfo(c, "sending gemini format response")
// send gemini format response
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
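// ProcessStreamResponse appends the delta content, reasoning content, and tool-call
// names/arguments of every choice in the chunk to responseTextBuilder, and updates
// toolCount to the largest number of tool calls seen in a single delta.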
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
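// processTokens joins the collected stream chunks into a JSON array and extracts
// their text according to the relay mode (chat completions or completions).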
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
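// processChatCompletions parses the joined chat-completion chunks in one pass;
// if that fails, it falls back to parsing each chunk individually.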
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// Batch unmarshal failed; fall back to parsing items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// Process all responses in one pass
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
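// processCompletions parses the joined completion chunks in one pass; if that fails,
// it falls back to parsing each chunk individually, skipping chunks that do not parse.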
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// Batch unmarshal failed; fall back to parsing items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// Process all responses in one pass
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
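// handleLastResponse extracts the response id, creation time, system fingerprint,
// model, and usage from the final stream chunk; when the chunk carries valid usage
// and the client did not ask for usage, the final chunk is not re-sent.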
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
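// HandleFinalResponse emits the closing stream output: a final usage chunk plus [DONE]
// for OpenAI, or the converted last chunk for the Claude and Gemini relay formats.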
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
_ = helper.ClaudeData(c, *resp)
}
case relaycommon.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
// What is handled here is the last OpenAI stream response: its delta is empty and it carries a finish_reason field.
// Compared with Google's official stream, the stream converted from OpenAI therefore has one extra response with empty parts and finishReason STOP,
// while the response containing the last piece of text output (the second to last) has a null finishReason.
// It is not yet known whether any client will be incompatible with this.
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// Empty data at the start of the OpenAI stream response
if geminiResponse == nil {
return
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.SysError("error marshalling gemini response: " + err.Error())
return
}
// Send the final Gemini response
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
}
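// sendResponsesStreamData forwards a Responses API stream chunk to the client,
// skipping empty payloads.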
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" {
return
}
helper.ResponseChunkData(c, streamResponse, data)
}