package openai import ( "encoding/json" "one-api/common" "one-api/dto" "one-api/logger" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/types" "strings" "github.com/samber/lo" "github.com/gin-gonic/gin" ) // 辅助函数 func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { info.SendResponseCount++ switch info.RelayFormat { case types.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) case types.RelayFormatClaude: return handleClaudeFormat(c, data, info) case types.RelayFormatGemini: return handleGeminiFormat(c, data, info) } return nil } func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { return err } if streamResponse.Usage != nil { info.ClaudeConvertInfo.Usage = streamResponse.Usage } claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) for _, resp := range claudeResponses { helper.ClaudeData(c, *resp) } return nil } func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) return err } geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) // 如果返回 nil,表示没有实际内容,跳过发送 if geminiResponse == nil { return nil } geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { logger.LogError(c, "failed to marshal gemini response: "+err.Error()) return err } // send gemini format response c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) _ = helper.FlushWriter(c) return nil } func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > *toolCount { *toolCount = len(choice.Delta.ToolCalls) } for _, tool := range choice.Delta.ToolCalls { responseTextBuilder.WriteString(tool.Function.Name) responseTextBuilder.WriteString(tool.Function.Arguments) } } } return nil } func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { streamResp := "[" + strings.Join(streamItems, ",") + "]" switch relayMode { case relayconstant.RelayModeChatCompletions: return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount) case relayconstant.RelayModeCompletions: return processCompletions(streamResp, streamItems, responseTextBuilder) } return nil } func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { common.SysLog("error processing stream response: " + err.Error()) } } return nil } // 批量处理所有响应 for _, streamResponse := range streamResponses { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > *toolCount { *toolCount = len(choice.Delta.ToolCalls) } for _, tool := range choice.Delta.ToolCalls { responseTextBuilder.WriteString(tool.Function.Name) responseTextBuilder.WriteString(tool.Function.Arguments) } } } } return nil } func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error { var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { continue } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Text) } } return nil } // 批量处理所有响应 for _, streamResponse := range streamResponses { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Text) } } return nil } func handleLastResponse(lastStreamData string, responseId *string, createAt *int64, systemFingerprint *string, model *string, usage **dto.Usage, containStreamUsage *bool, info *relaycommon.RelayInfo, shouldSendLastResp *bool) error { var lastStreamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { return err } *responseId = lastStreamResponse.Id *createAt = lastStreamResponse.Created *systemFingerprint = lastStreamResponse.GetSystemFingerprint() *model = lastStreamResponse.Model if service.ValidUsage(lastStreamResponse.Usage) { *containStreamUsage = true *usage = lastStreamResponse.Usage if !info.ShouldIncludeUsage { *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool { return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != "" }) } } return nil } func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, responseId string, createAt int64, model string, systemFingerprint string, usage *dto.Usage, containStreamUsage bool) { switch info.RelayFormat { case types.RelayFormatOpenAI: if info.ShouldIncludeUsage && !containStreamUsage { response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) helper.ObjectData(c, response) } helper.Done(c) case types.RelayFormatClaude: info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) return } info.ClaudeConvertInfo.Usage = usage claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) for _, resp := range claudeResponses { _ = helper.ClaudeData(c, *resp) } case types.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { common.SysLog("error unmarshalling stream response: " + err.Error()) return } // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段 // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应 // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null // 暂不知是否有程序会不兼容。 geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) // openai 流响应开头的空数据 if geminiResponse == nil { return } geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { common.SysLog("error marshalling gemini response: " + err.Error()) return } // 发送最终的 Gemini 响应 c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) _ = helper.FlushWriter(c) } } func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { if data == "" { return } helper.ResponseChunkData(c, streamResponse, data) }