feat: 初步兼容流模式下openai渠道类型转为claude格式访问 #862
This commit is contained in:
@@ -107,7 +107,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ type ClaudeMediaMessage struct {
|
|||||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||||
StopReason *string `json:"stop_reason,omitempty"`
|
StopReason *string `json:"stop_reason,omitempty"`
|
||||||
PartialJson string `json:"partial_json,omitempty"`
|
PartialJson *string `json:"partial_json,omitempty"`
|
||||||
Role string `json:"role,omitempty"`
|
Role string `json:"role,omitempty"`
|
||||||
Thinking string `json:"thinking,omitempty"`
|
Thinking string `json:"thinking,omitempty"`
|
||||||
Signature string `json:"signature,omitempty"`
|
Signature string `json:"signature,omitempty"`
|
||||||
@@ -37,6 +37,32 @@ func (c *ClaudeMediaMessage) GetText() string {
|
|||||||
return *c.Text
|
return *c.Text
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
||||||
|
var content string
|
||||||
|
return json.Unmarshal(c.Content, &content) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeMediaMessage) GetStringContent() string {
|
||||||
|
var content string
|
||||||
|
if err := json.Unmarshal(c.Content, &content); err == nil {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeMediaMessage) SetContent(content any) {
|
||||||
|
jsonContent, _ := json.Marshal(content)
|
||||||
|
c.Content = jsonContent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||||
|
var mediaContent []ClaudeMediaMessage
|
||||||
|
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
|
||||||
|
return mediaContent
|
||||||
|
}
|
||||||
|
return make([]ClaudeMediaMessage, 0)
|
||||||
|
}
|
||||||
|
|
||||||
type ClaudeMessageSource struct {
|
type ClaudeMessageSource struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
MediaType string `json:"media_type"`
|
MediaType string `json:"media_type"`
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ type Adaptor interface {
|
|||||||
Init(info *relaycommon.RelayInfo)
|
Init(info *relaycommon.RelayInfo)
|
||||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||||
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
common2 "one-api/common"
|
||||||
"one-api/relay/common"
|
"one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||||
}
|
}
|
||||||
|
if common2.DebugEnabled {
|
||||||
|
println("fullRequestURL:", fullRequestURL)
|
||||||
|
}
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
|||||||
case "input_json_delta":
|
case "input_json_delta":
|
||||||
tools = append(tools, dto.ToolCallResponse{
|
tools = append(tools, dto.ToolCallResponse{
|
||||||
Function: dto.FunctionResponse{
|
Function: dto.FunctionResponse{
|
||||||
Arguments: claudeResponse.Delta.PartialJson,
|
Arguments: *claudeResponse.Delta.PartialJson,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case "signature_delta":
|
case "signature_delta":
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
return requestOpenAI2Cohere(*request), nil
|
return requestOpenAI2Cohere(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,10 +30,20 @@ type Adaptor struct {
|
|||||||
ResponseFormat string
|
ResponseFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
if !strings.HasPrefix(request.Model, "claude") {
|
||||||
panic("implement me")
|
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||||
return nil, nil
|
}
|
||||||
|
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if info.SupportStreamOptions {
|
||||||
|
aiRequest.StreamOptions = &dto.StreamOptions{
|
||||||
|
IncludeUsage: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a.ConvertOpenAIRequest(c, info, aiRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
@@ -40,6 +51,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
|
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||||
|
}
|
||||||
if info.RelayMode == constant.RelayModeRealtime {
|
if info.RelayMode == constant.RelayModeRealtime {
|
||||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||||
@@ -115,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
188
relay/channel/openai/helper.go
Normal file
188
relay/channel/openai/helper.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/relay/helper"
|
||||||
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 辅助函数
|
||||||
|
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||||
|
info.SendResponseCount++
|
||||||
|
switch info.RelayFormat {
|
||||||
|
case relaycommon.RelayFormatOpenAI:
|
||||||
|
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||||
|
case relaycommon.RelayFormatClaude:
|
||||||
|
return handleClaudeFormat(c, data, info)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||||
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||||
|
for _, resp := range claudeResponses {
|
||||||
|
helper.ClaudeData(c, *resp)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
|
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||||
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||||
|
*toolCount = len(choice.Delta.ToolCalls)
|
||||||
|
}
|
||||||
|
for _, tool := range choice.Delta.ToolCalls {
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Name)
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
|
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||||
|
|
||||||
|
switch relayMode {
|
||||||
|
case relayconstant.RelayModeChatCompletions:
|
||||||
|
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||||
|
case relayconstant.RelayModeCompletions:
|
||||||
|
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
|
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||||
|
// 一次性解析失败,逐个解析
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
for _, item := range streamItems {
|
||||||
|
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
|
||||||
|
common.SysError("error processing stream response: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量处理所有响应
|
||||||
|
for _, streamResponse := range streamResponses {
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
|
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||||
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||||
|
*toolCount = len(choice.Delta.ToolCalls)
|
||||||
|
}
|
||||||
|
for _, tool := range choice.Delta.ToolCalls {
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Name)
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||||
|
var streamResponses []dto.CompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||||
|
// 一次性解析失败,逐个解析
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
for _, item := range streamItems {
|
||||||
|
var streamResponse dto.CompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseTextBuilder.WriteString(choice.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量处理所有响应
|
||||||
|
for _, streamResponse := range streamResponses {
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
responseTextBuilder.WriteString(choice.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||||
|
systemFingerprint *string, model *string, usage **dto.Usage,
|
||||||
|
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
||||||
|
shouldSendLastResp *bool) error {
|
||||||
|
|
||||||
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*responseId = lastStreamResponse.Id
|
||||||
|
*createAt = lastStreamResponse.Created
|
||||||
|
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||||
|
*model = lastStreamResponse.Model
|
||||||
|
|
||||||
|
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||||
|
*containStreamUsage = true
|
||||||
|
*usage = lastStreamResponse.Usage
|
||||||
|
if !info.ShouldIncludeUsage {
|
||||||
|
*shouldSendLastResp = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||||
|
responseId string, createAt int64, model string, systemFingerprint string,
|
||||||
|
usage *dto.Usage, containStreamUsage bool) {
|
||||||
|
|
||||||
|
switch info.RelayFormat {
|
||||||
|
case relaycommon.RelayFormatOpenAI:
|
||||||
|
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||||
|
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||||
|
response.SetSystemFingerprint(systemFingerprint)
|
||||||
|
helper.ObjectData(c, response)
|
||||||
|
}
|
||||||
|
helper.Done(c)
|
||||||
|
|
||||||
|
case relaycommon.RelayFormatClaude:
|
||||||
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !containStreamUsage {
|
||||||
|
streamResponse.Usage = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||||
|
for _, resp := range claudeResponses {
|
||||||
|
helper.ClaudeData(c, *resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
@@ -137,10 +136,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
if lastStreamData != "" {
|
if lastStreamData != "" {
|
||||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "streaming error: "+err.Error())
|
common.SysError("error handling stream format: " + err.Error())
|
||||||
}
|
}
|
||||||
|
info.SetFirstResponseTime()
|
||||||
}
|
}
|
||||||
lastStreamData = data
|
lastStreamData = data
|
||||||
streamItems = append(streamItems, data)
|
streamItems = append(streamItems, data)
|
||||||
@@ -172,83 +172,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算token
|
// 处理token计算
|
||||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||||
switch info.RelayMode {
|
common.SysError("error processing tokens: " + err.Error())
|
||||||
case relayconstant.RelayModeChatCompletions:
|
|
||||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
|
||||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
||||||
if err != nil {
|
|
||||||
// 一次性解析失败,逐个解析
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
for _, item := range streamItems {
|
|
||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
|
||||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
||||||
if err == nil {
|
|
||||||
//if service.ValidUsage(streamResponse.Usage) {
|
|
||||||
// usage = streamResponse.Usage
|
|
||||||
//}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
||||||
|
|
||||||
// handle both reasoning_content and reasoning
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
|
||||||
|
|
||||||
if choice.Delta.ToolCalls != nil {
|
|
||||||
if len(choice.Delta.ToolCalls) > toolCount {
|
|
||||||
toolCount = len(choice.Delta.ToolCalls)
|
|
||||||
}
|
|
||||||
for _, tool := range choice.Delta.ToolCalls {
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Name)
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for _, streamResponse := range streamResponses {
|
|
||||||
//if service.ValidUsage(streamResponse.Usage) {
|
|
||||||
// usage = streamResponse.Usage
|
|
||||||
// containStreamUsage = true
|
|
||||||
//}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
|
|
||||||
if choice.Delta.ToolCalls != nil {
|
|
||||||
if len(choice.Delta.ToolCalls) > toolCount {
|
|
||||||
toolCount = len(choice.Delta.ToolCalls)
|
|
||||||
}
|
|
||||||
for _, tool := range choice.Delta.ToolCalls {
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Name)
|
|
||||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case relayconstant.RelayModeCompletions:
|
|
||||||
var streamResponses []dto.CompletionsStreamResponse
|
|
||||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
||||||
if err != nil {
|
|
||||||
// 一次性解析失败,逐个解析
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
for _, item := range streamItems {
|
|
||||||
var streamResponse dto.CompletionsStreamResponse
|
|
||||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
||||||
if err == nil {
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for _, streamResponse := range streamResponses {
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseTextBuilder.WriteString(choice.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
@@ -262,15 +188,8 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
|
||||||
response.SetSystemFingerprint(systemFingerprint)
|
|
||||||
helper.ObjectData(c, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
helper.Done(c)
|
|
||||||
|
|
||||||
//resp.Body.Close()
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,13 +114,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println("requestBody: ", string(jsonData))
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
|
|
||||||
//log.Printf("requestBody: %s", requestBody)
|
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
|
|||||||
@@ -17,6 +17,16 @@ type ThinkingContentInfo struct {
|
|||||||
SendLastThinkingContent bool
|
SendLastThinkingContent bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
LastMessageTypeText = "text"
|
||||||
|
LastMessageTypeTools = "tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeConvertInfo struct {
|
||||||
|
LastMessagesType string
|
||||||
|
Index int
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RelayFormatOpenAI = "openai"
|
RelayFormatOpenAI = "openai"
|
||||||
RelayFormatClaude = "claude"
|
RelayFormatClaude = "claude"
|
||||||
@@ -64,8 +74,9 @@ type RelayInfo struct {
|
|||||||
UserEmail string
|
UserEmail string
|
||||||
UserQuota int
|
UserQuota int
|
||||||
RelayFormat string
|
RelayFormat string
|
||||||
ResponseTimes int64
|
SendResponseCount int
|
||||||
ThinkingContentInfo
|
ThinkingContentInfo
|
||||||
|
ClaudeConvertInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义支持流式选项的通道类型
|
// 定义支持流式选项的通道类型
|
||||||
@@ -93,6 +104,9 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
|||||||
info := GenRelayInfo(c)
|
info := GenRelayInfo(c)
|
||||||
info.RelayFormat = RelayFormatClaude
|
info.RelayFormat = RelayFormatClaude
|
||||||
info.ShouldIncludeUsage = false
|
info.ShouldIncludeUsage = false
|
||||||
|
info.ClaudeConvertInfo = ClaudeConvertInfo{
|
||||||
|
LastMessagesType: LastMessageTypeText,
|
||||||
|
}
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +186,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (info *RelayInfo) SetFirstResponseTime() {
|
func (info *RelayInfo) SetFirstResponseTime() {
|
||||||
info.ResponseTimes++
|
|
||||||
if info.isFirstResponse {
|
if info.isFirstResponse {
|
||||||
info.FirstResponseTime = time.Now()
|
info.FirstResponseTime = time.Now()
|
||||||
info.isFirstResponse = false
|
info.isFirstResponse = false
|
||||||
|
|||||||
@@ -19,6 +19,22 @@ func SetEventStreamHeaders(c *gin.Context) {
|
|||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||||
|
jsonData, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
} else {
|
||||||
|
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
|
||||||
|
}
|
||||||
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
return errors.New("streaming error: flusher not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
||||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(body)
|
requestBody = bytes.NewBuffer(body)
|
||||||
} else {
|
} else {
|
||||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -168,6 +168,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println("requestBody: ", string(jsonData))
|
||||||
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,24 +44,26 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
|
|||||||
openAIMessages := make([]dto.Message, 0)
|
openAIMessages := make([]dto.Message, 0)
|
||||||
|
|
||||||
// Add system message if present
|
// Add system message if present
|
||||||
if claudeRequest.IsStringSystem() {
|
if claudeRequest.System != nil {
|
||||||
openAIMessage := dto.Message{
|
if claudeRequest.IsStringSystem() {
|
||||||
Role: "system",
|
|
||||||
}
|
|
||||||
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
|
|
||||||
openAIMessages = append(openAIMessages, openAIMessage)
|
|
||||||
} else {
|
|
||||||
systems := claudeRequest.ParseSystem()
|
|
||||||
if len(systems) > 0 {
|
|
||||||
systemStr := ""
|
|
||||||
openAIMessage := dto.Message{
|
openAIMessage := dto.Message{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
}
|
}
|
||||||
for _, system := range systems {
|
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
|
||||||
systemStr += system.Type
|
|
||||||
}
|
|
||||||
openAIMessage.SetStringContent(systemStr)
|
|
||||||
openAIMessages = append(openAIMessages, openAIMessage)
|
openAIMessages = append(openAIMessages, openAIMessage)
|
||||||
|
} else {
|
||||||
|
systems := claudeRequest.ParseSystem()
|
||||||
|
if len(systems) > 0 {
|
||||||
|
systemStr := ""
|
||||||
|
openAIMessage := dto.Message{
|
||||||
|
Role: "system",
|
||||||
|
}
|
||||||
|
for _, system := range systems {
|
||||||
|
systemStr += system.Type
|
||||||
|
}
|
||||||
|
openAIMessage.SetStringContent(systemStr)
|
||||||
|
openAIMessages = append(openAIMessages, openAIMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, claudeMessage := range claudeRequest.Messages {
|
for _, claudeMessage := range claudeRequest.Messages {
|
||||||
@@ -100,7 +102,8 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
|
|||||||
mediaMessages = append(mediaMessages, mediaMessage)
|
mediaMessages = append(mediaMessages, mediaMessage)
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
toolCall := dto.ToolCallRequest{
|
toolCall := dto.ToolCallRequest{
|
||||||
ID: mediaMsg.Id,
|
ID: mediaMsg.Id,
|
||||||
|
Type: "function",
|
||||||
Function: dto.FunctionRequest{
|
Function: dto.FunctionRequest{
|
||||||
Name: mediaMsg.Name,
|
Name: mediaMsg.Name,
|
||||||
Arguments: toJSONString(mediaMsg.Input),
|
Arguments: toJSONString(mediaMsg.Input),
|
||||||
@@ -111,20 +114,33 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
|
|||||||
// Add tool result as a separate message
|
// Add tool result as a separate message
|
||||||
oaiToolMessage := dto.Message{
|
oaiToolMessage := dto.Message{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
|
Name: &mediaMsg.Name,
|
||||||
ToolCallId: mediaMsg.ToolUseId,
|
ToolCallId: mediaMsg.ToolUseId,
|
||||||
}
|
}
|
||||||
oaiToolMessage.Content = mediaMsg.Content
|
//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
|
||||||
|
if mediaMsg.IsStringContent() {
|
||||||
|
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
|
||||||
|
} else {
|
||||||
|
mediaContents := mediaMsg.ParseMediaContent()
|
||||||
|
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
|
||||||
|
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
openAIMessages = append(openAIMessages, oaiToolMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
openAIMessage.SetMediaContent(mediaMessages)
|
if len(mediaMessages) > 0 {
|
||||||
|
openAIMessage.SetMediaContent(mediaMessages)
|
||||||
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
openAIMessage.SetToolCalls(toolCalls)
|
openAIMessage.SetToolCalls(toolCalls)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(openAIMessage.ParseContent()) > 0 {
|
||||||
openAIMessages = append(openAIMessages, openAIMessage)
|
openAIMessages = append(openAIMessages, openAIMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
openAIRequest.Messages = openAIMessages
|
openAIRequest.Messages = openAIMessages
|
||||||
@@ -154,22 +170,35 @@ func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.O
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateStopBlock(index int) *dto.ClaudeResponse {
|
||||||
|
return &dto.ClaudeResponse{
|
||||||
|
Type: "content_block_stop",
|
||||||
|
Index: common.GetPointer[int](index),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
|
||||||
var claudeResponses []*dto.ClaudeResponse
|
var claudeResponses []*dto.ClaudeResponse
|
||||||
if info.ResponseTimes == 1 {
|
if info.SendResponseCount == 1 {
|
||||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
msg := &dto.ClaudeMediaMessage{
|
||||||
Type: "message_start",
|
Id: openAIResponse.Id,
|
||||||
Message: &dto.ClaudeMediaMessage{
|
Model: openAIResponse.Model,
|
||||||
Id: openAIResponse.Id,
|
Type: "message",
|
||||||
Model: openAIResponse.Model,
|
Role: "assistant",
|
||||||
Type: "message",
|
Usage: &dto.ClaudeUsage{
|
||||||
Role: "assistant",
|
InputTokens: info.PromptTokens,
|
||||||
Usage: &dto.ClaudeUsage{
|
OutputTokens: 0,
|
||||||
InputTokens: info.PromptTokens,
|
|
||||||
OutputTokens: 0,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
msg.SetContent(make([]any, 0))
|
||||||
|
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||||
|
Type: "message_start",
|
||||||
|
Message: msg,
|
||||||
})
|
})
|
||||||
|
claudeResponses = append(claudeResponses)
|
||||||
|
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||||
|
// Type: "ping",
|
||||||
|
//})
|
||||||
if openAIResponse.IsToolCall() {
|
if openAIResponse.IsToolCall() {
|
||||||
resp := &dto.ClaudeResponse{
|
resp := &dto.ClaudeResponse{
|
||||||
Type: "content_block_start",
|
Type: "content_block_start",
|
||||||
@@ -192,23 +221,18 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
|||||||
resp.SetIndex(0)
|
resp.SetIndex(0)
|
||||||
claudeResponses = append(claudeResponses, resp)
|
claudeResponses = append(claudeResponses, resp)
|
||||||
}
|
}
|
||||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
|
||||||
Type: "ping",
|
|
||||||
})
|
|
||||||
return claudeResponses
|
return claudeResponses
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(openAIResponse.Choices) == 0 {
|
if len(openAIResponse.Choices) == 0 {
|
||||||
// no choices
|
// no choices
|
||||||
// TODO: handle this case
|
// TODO: handle this case
|
||||||
|
return claudeResponses
|
||||||
} else {
|
} else {
|
||||||
chosenChoice := openAIResponse.Choices[0]
|
chosenChoice := openAIResponse.Choices[0]
|
||||||
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
|
||||||
// should be done
|
// should be done
|
||||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||||
Type: "content_block_stop",
|
|
||||||
Index: common.GetPointer[int](0),
|
|
||||||
})
|
|
||||||
if openAIResponse.Usage != nil {
|
if openAIResponse.Usage != nil {
|
||||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||||
Type: "message_delta",
|
Type: "message_delta",
|
||||||
@@ -229,18 +253,35 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
|||||||
claudeResponse.SetIndex(0)
|
claudeResponse.SetIndex(0)
|
||||||
claudeResponse.Type = "content_block_delta"
|
claudeResponse.Type = "content_block_delta"
|
||||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||||
|
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
|
||||||
|
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||||
|
info.ClaudeConvertInfo.Index++
|
||||||
|
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||||
|
Index: &info.ClaudeConvertInfo.Index,
|
||||||
|
Type: "content_block_start",
|
||||||
|
ContentBlock: &dto.ClaudeMediaMessage{
|
||||||
|
Id: openAIResponse.GetFirstToolCall().ID,
|
||||||
|
Type: "tool_use",
|
||||||
|
Name: openAIResponse.GetFirstToolCall().Function.Name,
|
||||||
|
Input: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||||
// tools delta
|
// tools delta
|
||||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||||
Type: "input_json_delta",
|
Type: "input_json_delta",
|
||||||
PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
|
||||||
// text delta
|
// text delta
|
||||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||||
Type: "text_delta",
|
Type: "text_delta",
|
||||||
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
|
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user