feat: 初步兼容流模式下openai渠道类型转为claude格式访问 #862

This commit is contained in:
1808837298@qq.com
2025-03-13 19:32:08 +08:00
parent c25d4d8d23
commit 7e46d4217d
37 changed files with 390 additions and 165 deletions

View File

@@ -107,7 +107,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
adaptor.Init(info) adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, info, request) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }

View File

@@ -13,7 +13,7 @@ type ClaudeMediaMessage struct {
Source *ClaudeMessageSource `json:"source,omitempty"` Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"` PartialJson *string `json:"partial_json,omitempty"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Thinking string `json:"thinking,omitempty"` Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"` Signature string `json:"signature,omitempty"`
@@ -37,6 +37,32 @@ func (c *ClaudeMediaMessage) GetText() string {
return *c.Text return *c.Text
} }
func (c *ClaudeMediaMessage) IsStringContent() bool {
var content string
return json.Unmarshal(c.Content, &content) == nil
}
func (c *ClaudeMediaMessage) GetStringContent() string {
var content string
if err := json.Unmarshal(c.Content, &content); err == nil {
return content
}
return ""
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
var mediaContent []ClaudeMediaMessage
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
Type string `json:"type"` Type string `json:"type"`
MediaType string `json:"media_type"` MediaType string `json:"media_type"`

View File

@@ -13,7 +13,7 @@ type Adaptor interface {
Init(info *relaycommon.RelayInfo) Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)

View File

@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
common2 "one-api/common"
"one-api/relay/common" "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)
} }
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)

View File

@@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -110,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -64,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -335,7 +335,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
case "input_json_delta": case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{ tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{ Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson, Arguments: *claudeResponse.Delta.PartialJson,
}, },
}) })
case "signature_delta": case "signature_delta":

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil return requestOpenAI2Cohere(*request), nil
} }

View File

@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -70,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -95,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -57,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -21,6 +21,7 @@ import (
"one-api/relay/channel/xinference" "one-api/relay/channel/xinference"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service"
"strings" "strings"
) )
@@ -29,10 +30,20 @@ type Adaptor struct {
ResponseFormat string ResponseFormat string
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
//TODO implement me if !strings.HasPrefix(request.Model, "claude") {
panic("implement me") return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
return nil, nil }
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +51,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime { if info.RelayMode == constant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") { if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -115,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -0,0 +1,188 @@
package openai
import (
"encoding/json"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
return nil
}
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
}
}

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"os" "os"
@@ -137,10 +136,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" { if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil { if err != nil {
common.LogError(c, "streaming error: "+err.Error()) common.SysError("error handling stream format: " + err.Error())
} }
info.SetFirstResponseTime()
} }
lastStreamData = data lastStreamData = data
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
@@ -172,83 +172,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
} }
// 计算token // 处理token计算
streamResp := "[" + strings.Join(streamItems, ",") + "]" if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
switch info.RelayMode { common.SysError("error processing tokens: " + err.Error())
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
// handle both reasoning_content and reasoning
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} }
if !containStreamUsage { if !containStreamUsage {
@@ -262,15 +188,8 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
if info.ShouldIncludeUsage && !containStreamUsage { handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
//resp.Body.Close()
return nil, usage return nil, usage
} }

View File

@@ -46,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -54,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }

View File

@@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -122,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -56,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -45,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@@ -114,13 +114,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
} }
jsonData, err := json.Marshal(convertedRequest) jsonData, err := json.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil { if err != nil {
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
//log.Printf("requestBody: %s", requestBody)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)

View File

@@ -17,6 +17,16 @@ type ThinkingContentInfo struct {
SendLastThinkingContent bool SendLastThinkingContent bool
} }
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
}
const ( const (
RelayFormatOpenAI = "openai" RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude" RelayFormatClaude = "claude"
@@ -64,8 +74,9 @@ type RelayInfo struct {
UserEmail string UserEmail string
UserQuota int UserQuota int
RelayFormat string RelayFormat string
ResponseTimes int64 SendResponseCount int
ThinkingContentInfo ThinkingContentInfo
ClaudeConvertInfo
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -93,6 +104,9 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
}
return info return info
} }
@@ -172,7 +186,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
} }
func (info *RelayInfo) SetFirstResponseTime() { func (info *RelayInfo) SetFirstResponseTime() {
info.ResponseTimes++
if info.isFirstResponse { if info.isFirstResponse {
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
info.isFirstResponse = false info.isFirstResponse = false

View File

@@ -19,6 +19,22 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Header().Set("X-Accel-Buffering", "no")
} }
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
jsonData, err := json.Marshal(resp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
} else {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})

View File

@@ -160,7 +160,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
requestBody = bytes.NewBuffer(body) requestBody = bytes.NewBuffer(body)
} else { } else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
} }
@@ -168,6 +168,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
} }
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
} }

View File

@@ -44,24 +44,26 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
openAIMessages := make([]dto.Message, 0) openAIMessages := make([]dto.Message, 0)
// Add system message if present // Add system message if present
if claudeRequest.IsStringSystem() { if claudeRequest.System != nil {
openAIMessage := dto.Message{ if claudeRequest.IsStringSystem() {
Role: "system",
}
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
openAIMessages = append(openAIMessages, openAIMessage)
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
systemStr := ""
openAIMessage := dto.Message{ openAIMessage := dto.Message{
Role: "system", Role: "system",
} }
for _, system := range systems { openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
systemStr += system.Type
}
openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage) openAIMessages = append(openAIMessages, openAIMessage)
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
systemStr := ""
openAIMessage := dto.Message{
Role: "system",
}
for _, system := range systems {
systemStr += system.Type
}
openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage)
}
} }
} }
for _, claudeMessage := range claudeRequest.Messages { for _, claudeMessage := range claudeRequest.Messages {
@@ -100,7 +102,8 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
mediaMessages = append(mediaMessages, mediaMessage) mediaMessages = append(mediaMessages, mediaMessage)
case "tool_use": case "tool_use":
toolCall := dto.ToolCallRequest{ toolCall := dto.ToolCallRequest{
ID: mediaMsg.Id, ID: mediaMsg.Id,
Type: "function",
Function: dto.FunctionRequest{ Function: dto.FunctionRequest{
Name: mediaMsg.Name, Name: mediaMsg.Name,
Arguments: toJSONString(mediaMsg.Input), Arguments: toJSONString(mediaMsg.Input),
@@ -111,20 +114,33 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
// Add tool result as a separate message // Add tool result as a separate message
oaiToolMessage := dto.Message{ oaiToolMessage := dto.Message{
Role: "tool", Role: "tool",
Name: &mediaMsg.Name,
ToolCallId: mediaMsg.ToolUseId, ToolCallId: mediaMsg.ToolUseId,
} }
oaiToolMessage.Content = mediaMsg.Content //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
if mediaMsg.IsStringContent() {
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
}
openAIMessages = append(openAIMessages, oaiToolMessage)
} }
} }
openAIMessage.SetMediaContent(mediaMessages) if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls) openAIMessage.SetToolCalls(toolCalls)
} }
} }
if len(openAIMessage.ParseContent()) > 0 {
openAIMessages = append(openAIMessages, openAIMessage) openAIMessages = append(openAIMessages, openAIMessage)
}
} }
openAIRequest.Messages = openAIMessages openAIRequest.Messages = openAIMessages
@@ -154,22 +170,35 @@ func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.O
} }
} }
func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{
Type: "content_block_stop",
Index: common.GetPointer[int](index),
}
}
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse { func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
var claudeResponses []*dto.ClaudeResponse var claudeResponses []*dto.ClaudeResponse
if info.ResponseTimes == 1 { if info.SendResponseCount == 1 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ msg := &dto.ClaudeMediaMessage{
Type: "message_start", Id: openAIResponse.Id,
Message: &dto.ClaudeMediaMessage{ Model: openAIResponse.Model,
Id: openAIResponse.Id, Type: "message",
Model: openAIResponse.Model, Role: "assistant",
Type: "message", Usage: &dto.ClaudeUsage{
Role: "assistant", InputTokens: info.PromptTokens,
Usage: &dto.ClaudeUsage{ OutputTokens: 0,
InputTokens: info.PromptTokens,
OutputTokens: 0,
},
}, },
}
msg.SetContent(make([]any, 0))
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_start",
Message: msg,
}) })
claudeResponses = append(claudeResponses)
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
// Type: "ping",
//})
if openAIResponse.IsToolCall() { if openAIResponse.IsToolCall() {
resp := &dto.ClaudeResponse{ resp := &dto.ClaudeResponse{
Type: "content_block_start", Type: "content_block_start",
@@ -192,23 +221,18 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0) resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp) claudeResponses = append(claudeResponses, resp)
} }
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "ping",
})
return claudeResponses return claudeResponses
} }
if len(openAIResponse.Choices) == 0 { if len(openAIResponse.Choices) == 0 {
// no choices // no choices
// TODO: handle this case // TODO: handle this case
return claudeResponses
} else { } else {
chosenChoice := openAIResponse.Choices[0] chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done // should be done
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
Type: "content_block_stop",
Index: common.GetPointer[int](0),
})
if openAIResponse.Usage != nil { if openAIResponse.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta", Type: "message_delta",
@@ -229,18 +253,35 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
claudeResponse.SetIndex(0) claudeResponse.SetIndex(0)
claudeResponse.Type = "content_block_delta" claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 { if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: openAIResponse.GetFirstToolCall().ID,
Type: "tool_use",
Name: openAIResponse.GetFirstToolCall().Function.Name,
Input: map[string]interface{}{},
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
// tools delta // tools delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{ claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "input_json_delta", Type: "input_json_delta",
PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments, PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
} }
} else { } else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta // text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{ claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta", Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()), Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
} }
} }
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse) claudeResponses = append(claudeResponses, &claudeResponse)
} }
} }