feat: 初步兼容流模式下openai渠道类型转为claude格式访问 #862

This commit is contained in:
1808837298@qq.com
2025-03-13 19:32:08 +08:00
parent c25d4d8d23
commit 7e46d4217d
37 changed files with 390 additions and 165 deletions

View File

@@ -107,7 +107,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil {
return err, nil
}

View File

@@ -13,7 +13,7 @@ type ClaudeMediaMessage struct {
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
PartialJson *string `json:"partial_json,omitempty"`
Role string `json:"role,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
@@ -37,6 +37,32 @@ func (c *ClaudeMediaMessage) GetText() string {
return *c.Text
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
var content string
return json.Unmarshal(c.Content, &content) == nil
}
func (c *ClaudeMediaMessage) GetStringContent() string {
var content string
if err := json.Unmarshal(c.Content, &content); err == nil {
return content
}
return ""
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
var mediaContent []ClaudeMediaMessage
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`

View File

@@ -13,7 +13,7 @@ type Adaptor interface {
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)

View File

@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/websocket"
"io"
"net/http"
common2 "one-api/common"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)

View File

@@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -110,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -64,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -335,7 +335,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}

View File

@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -70,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -95,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -57,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -21,6 +21,7 @@ import (
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"strings"
)
@@ -29,10 +30,20 @@ type Adaptor struct {
ResponseFormat string
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if !strings.HasPrefix(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +51,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -115,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -0,0 +1,188 @@
package openai
import (
"encoding/json"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
return nil
}
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
}
}

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"os"
@@ -137,10 +136,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -172,83 +172,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
// handle both reasoning_content and reasoning
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.SysError("error processing tokens: " + err.Error())
}
if !containStreamUsage {
@@ -262,15 +188,8 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
helper.Done(c)
//resp.Body.Close()
return nil, usage
}

View File

@@ -46,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -54,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}

View File

@@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -122,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -56,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -45,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -114,13 +114,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
//log.Printf("requestBody: %s", requestBody)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)

View File

@@ -17,6 +17,16 @@ type ThinkingContentInfo struct {
SendLastThinkingContent bool
}
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
@@ -64,8 +74,9 @@ type RelayInfo struct {
UserEmail string
UserQuota int
RelayFormat string
ResponseTimes int64
SendResponseCount int
ThinkingContentInfo
ClaudeConvertInfo
}
// 定义支持流式选项的通道类型
@@ -93,6 +104,9 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
}
return info
}
@@ -172,7 +186,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
info.ResponseTimes++
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
info.isFirstResponse = false

View File

@@ -19,6 +19,22 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
jsonData, err := json.Marshal(resp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
} else {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})

View File

@@ -160,7 +160,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
@@ -168,6 +168,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
requestBody = bytes.NewBuffer(jsonData)
}

View File

@@ -44,24 +44,26 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
openAIMessages := make([]dto.Message, 0)
// Add system message if present
if claudeRequest.IsStringSystem() {
openAIMessage := dto.Message{
Role: "system",
}
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
openAIMessages = append(openAIMessages, openAIMessage)
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
systemStr := ""
if claudeRequest.System != nil {
if claudeRequest.IsStringSystem() {
openAIMessage := dto.Message{
Role: "system",
}
for _, system := range systems {
systemStr += system.Type
}
openAIMessage.SetStringContent(systemStr)
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
openAIMessages = append(openAIMessages, openAIMessage)
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
systemStr := ""
openAIMessage := dto.Message{
Role: "system",
}
for _, system := range systems {
systemStr += system.Type
}
openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage)
}
}
}
for _, claudeMessage := range claudeRequest.Messages {
@@ -100,7 +102,8 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
mediaMessages = append(mediaMessages, mediaMessage)
case "tool_use":
toolCall := dto.ToolCallRequest{
ID: mediaMsg.Id,
ID: mediaMsg.Id,
Type: "function",
Function: dto.FunctionRequest{
Name: mediaMsg.Name,
Arguments: toJSONString(mediaMsg.Input),
@@ -111,20 +114,33 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
// Add tool result as a separate message
oaiToolMessage := dto.Message{
Role: "tool",
Name: &mediaMsg.Name,
ToolCallId: mediaMsg.ToolUseId,
}
oaiToolMessage.Content = mediaMsg.Content
//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
if mediaMsg.IsStringContent() {
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
}
openAIMessages = append(openAIMessages, oaiToolMessage)
}
}
openAIMessage.SetMediaContent(mediaMessages)
if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls)
}
}
openAIMessages = append(openAIMessages, openAIMessage)
if len(openAIMessage.ParseContent()) > 0 {
openAIMessages = append(openAIMessages, openAIMessage)
}
}
openAIRequest.Messages = openAIMessages
@@ -154,22 +170,35 @@ func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.O
}
}
func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{
Type: "content_block_stop",
Index: common.GetPointer[int](index),
}
}
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
var claudeResponses []*dto.ClaudeResponse
if info.ResponseTimes == 1 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_start",
Message: &dto.ClaudeMediaMessage{
Id: openAIResponse.Id,
Model: openAIResponse.Model,
Type: "message",
Role: "assistant",
Usage: &dto.ClaudeUsage{
InputTokens: info.PromptTokens,
OutputTokens: 0,
},
if info.SendResponseCount == 1 {
msg := &dto.ClaudeMediaMessage{
Id: openAIResponse.Id,
Model: openAIResponse.Model,
Type: "message",
Role: "assistant",
Usage: &dto.ClaudeUsage{
InputTokens: info.PromptTokens,
OutputTokens: 0,
},
}
msg.SetContent(make([]any, 0))
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_start",
Message: msg,
})
claudeResponses = append(claudeResponses)
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
// Type: "ping",
//})
if openAIResponse.IsToolCall() {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
@@ -192,23 +221,18 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "ping",
})
return claudeResponses
}
if len(openAIResponse.Choices) == 0 {
// no choices
// TODO: handle this case
return claudeResponses
} else {
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "content_block_stop",
Index: common.GetPointer[int](0),
})
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
if openAIResponse.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
@@ -229,18 +253,35 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
claudeResponse.SetIndex(0)
claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: openAIResponse.GetFirstToolCall().ID,
Type: "tool_use",
Name: openAIResponse.GetFirstToolCall().Function.Name,
Input: map[string]interface{}{},
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
// tools delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments,
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
}
} else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse)
}
}