feat: 初步兼容流模式下openai渠道类型转为claude格式访问 #862
This commit is contained in:
188
relay/channel/openai/helper.go
Normal file
188
relay/channel/openai/helper.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysError("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
systemFingerprint *string, model *string, usage **dto.Usage,
|
||||
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
||||
shouldSendLastResp *bool) error {
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*responseId = lastStreamResponse.Id
|
||||
*createAt = lastStreamResponse.Created
|
||||
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||
*model = lastStreamResponse.Model
|
||||
|
||||
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||
*containStreamUsage = true
|
||||
*usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
*shouldSendLastResp = false
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
case relaycommon.RelayFormatClaude:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
streamResponse.Usage = usage
|
||||
}
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user