diff --git a/constant/channel_setting.go b/constant/channel_setting.go index 6eccfb84..e06e7eb1 100644 --- a/constant/channel_setting.go +++ b/constant/channel_setting.go @@ -1,6 +1,7 @@ package constant var ( - ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 - ChanelSettingProxy = "proxy" // Proxy 代理 + ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 + ChanelSettingProxy = "proxy" // Proxy 代理 + ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent ) diff --git a/docs/channel/other_setting.md b/docs/channel/other_setting.md index 775da557..b3f4f969 100644 --- a/docs/channel/other_setting.md +++ b/docs/channel/other_setting.md @@ -10,6 +10,10 @@ - 用于配置网络代理 - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址) +3. thinking_to_content + - 用于标识是否将思考内容`reasoning_conetnt`转换为``标签拼接到内容中返回 + - 类型为布尔值,设置为 true 时启用思考内容转换 + -------------------------------------------------------------- ## JSON 格式示例 @@ -19,6 +23,7 @@ ```json { "force_format": true, + "thinking_to_content": true, "proxy": "socks5://xxxxxxx" } ``` diff --git a/dto/openai_response.go b/dto/openai_response.go index febf01ff..3a4c971b 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -86,6 +86,10 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string return *c.ReasoningContent } +func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) { + c.ReasoningContent = &s +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` @@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { + choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) + copy(choices, c.Choices) + return &ChatCompletionsStreamResponse{ + Id: c.Id, + Object: c.Object, + Created: c.Created, + Model: c.Model, + SystemFingerprint: c.SystemFingerprint, + Choices: choices, + Usage: c.Usage, + } +} + func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { if c.SystemFingerprint == nil { return "" diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 33cdea48..a5bd0e33 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,10 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "github.com/pkg/errors" "io" "math" "mime/multipart" @@ -23,21 +19,66 @@ import ( "strings" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/pkg/errors" ) -func sendStreamData(c *gin.Context, data string, forceFormat bool) error { +func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { if data == "" { return nil } - if forceFormat { - var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { - return err - } + if !forceFormat && !thinkToContent { + return service.StringData(c, data) + } + + var lastStreamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { + return err + } + + if !thinkToContent { return service.ObjectData(c, lastStreamResponse) } - return service.StringData(c, data) + + // Handle think to content conversion + if info.IsFirstResponse { + response := lastStreamResponse.Copy() + for i := range response.Choices { + response.Choices[i].Delta.SetContentString("\n") + response.Choices[i].Delta.SetReasoningContent("") + } + service.ObjectData(c, response) + } + + if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 { + return service.ObjectData(c, lastStreamResponse) + } + + // Process each choice + for i, choice := range lastStreamResponse.Choices { + // Handle transition from thinking to content + if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse { + response := lastStreamResponse.Copy() + for j := range response.Choices { + response.Choices[j].Delta.SetContentString("\n") + response.Choices[j].Delta.SetReasoningContent("") + } + info.SendLastReasoningResponse = true + service.ObjectData(c, response) + } + + // Convert reasoning content to regular content + if len(choice.Delta.GetReasoningContent()) > 0 { + lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent()) + lastStreamResponse.Choices[i].Delta.SetReasoningContent("") + } + } + + return service.ObjectData(c, lastStreamResponse) } func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var usage = &dto.Usage{} var streamItems []string // store stream items var forceFormat bool + var thinkToContent bool - if info.ChannelType == common.ChannelTypeCustom { - if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok { - forceFormat = forceFmt - } + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok { + thinkToContent = think2Content } toolCount := 0 @@ -84,7 +128,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel ) gopool.Go(func() { for scanner.Scan() { - info.SetFirstResponseTime() + //info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if common.DebugEnabled { @@ -101,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel data = strings.TrimSpace(data) if !strings.HasPrefix(data, "[DONE]") { if lastStreamData != "" { - err := sendStreamData(c, lastStreamData, forceFormat) + err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) if err != nil { common.LogError(c, "streaming error: "+err.Error()) } + info.SetFirstResponseTime() } lastStreamData = data streamItems = append(streamItems, data) @@ -144,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } if shouldSendLastResp { - sendStreamData(c, lastStreamData, forceFormat) + sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } // 计算token diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 1f4a3a42..e1ecd83a 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -13,23 +13,24 @@ import ( ) type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - setFirstResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + IsFirstResponse bool + SendLastReasoningResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string @@ -88,6 +89,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { apiType, _ := relayconstant.ChannelType2APIType(channelType) info := &RelayInfo{ + IsFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), BaseUrl: c.GetString("base_url"), RequestURLPath: c.Request.URL.String(), @@ -139,9 +141,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - if !info.setFirstResponse { + if info.IsFirstResponse { info.FirstResponseTime = time.Now() - info.setFirstResponse = true + info.IsFirstResponse = false } }