feat: Add thinking-to-content conversion for stream responses

This commit is contained in:
1808837298@qq.com
2025-02-23 17:05:57 +08:00
parent 88a2fec190
commit 115a181db3
5 changed files with 110 additions and 39 deletions

View File

@@ -1,6 +1,7 @@
package constant package constant
// Keys recognized in a channel's JSON settings object. They are declared as
// constants because they are fixed lookup keys and must never be reassigned.
const (
	// ForceFormat is the key of the boolean setting that forces responses to
	// be re-encoded in the OpenAI format.
	ForceFormat = "force_format"
	// ChanelSettingProxy is the key of the string setting that holds the
	// network proxy address. The identifier's spelling ("Chanel") is kept
	// as-is so existing references keep compiling.
	ChanelSettingProxy = "proxy"
	// ChannelSettingThinkingToContent is the key of the boolean setting that
	// enables converting reasoning content into <think>-tagged regular
	// content in the returned response.
	ChannelSettingThinkingToContent = "thinking_to_content"
)

View File

@@ -10,6 +10,10 @@
- 用于配置网络代理 - 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址) - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
-------------------------------------------------------------- --------------------------------------------------------------
## JSON 格式示例 ## JSON 格式示例
@@ -19,6 +23,7 @@
```json ```json
{ {
"force_format": true, "force_format": true,
"thinking_to_content": true,
"proxy": "socks5://xxxxxxx" "proxy": "socks5://xxxxxxx"
} }
``` ```

View File

@@ -86,6 +86,10 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
return *c.ReasoningContent return *c.ReasoningContent
} }
// SetReasoningContent stores s as the delta's reasoning content. The field is
// pointed at a fresh string, so any pointer previously held in
// ReasoningContent is replaced rather than written through.
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
	reasoning := s
	c.ReasoningContent = &reasoning
}
type ToolCall struct { type ToolCall struct {
// Index is not nil only in chat completion chunk object // Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
@@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
} }
// Copy returns a new ChatCompletionsStreamResponse holding the same field
// values as c, with Choices duplicated into a freshly allocated slice.
//
// The copy is shallow: each choice is copied by value, so pointer fields
// inside a choice, as well as SystemFingerprint and Usage, still refer to the
// same underlying data as c. The returned Choices slice is never nil, even
// when c.Choices is.
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
	dup := &ChatCompletionsStreamResponse{
		Id:                c.Id,
		Object:            c.Object,
		Created:           c.Created,
		Model:             c.Model,
		SystemFingerprint: c.SystemFingerprint,
		Choices:           make([]ChatCompletionsStreamResponseChoice, len(c.Choices)),
		Usage:             c.Usage,
	}
	copy(dup.Choices, c.Choices)
	return dup
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil { if c.SystemFingerprint == nil {
return "" return ""

View File

@@ -5,10 +5,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io" "io"
"math" "math"
"mime/multipart" "mime/multipart"
@@ -23,21 +19,66 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
) )
func sendStreamData(c *gin.Context, data string, forceFormat bool) error { func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" { if data == "" {
return nil return nil
} }
if forceFormat { if !forceFormat && !thinkToContent {
var lastStreamResponse dto.ChatCompletionsStreamResponse return service.StringData(c, data)
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { }
return err
} var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return service.ObjectData(c, lastStreamResponse) return service.ObjectData(c, lastStreamResponse)
} }
return service.StringData(c, data)
// Handle think to content conversion
if info.IsFirstResponse {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
service.ObjectData(c, response)
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>")
response.Choices[j].Delta.SetReasoningContent("")
}
info.SendLastReasoningResponse = true
service.ObjectData(c, response)
}
// Convert reasoning content to regular content
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
}
}
return service.ObjectData(c, lastStreamResponse)
} }
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage = &dto.Usage{} var usage = &dto.Usage{}
var streamItems []string // store stream items var streamItems []string // store stream items
var forceFormat bool var forceFormat bool
var thinkToContent bool
if info.ChannelType == common.ChannelTypeCustom { if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok { forceFormat = forceFmt
forceFormat = forceFmt }
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
} }
toolCount := 0 toolCount := 0
@@ -84,7 +128,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
) )
gopool.Go(func() { gopool.Go(func() {
for scanner.Scan() { for scanner.Scan() {
info.SetFirstResponseTime() //info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text() data := scanner.Text()
if common.DebugEnabled { if common.DebugEnabled {
@@ -101,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
data = strings.TrimSpace(data) data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" { if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat) err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil { if err != nil {
common.LogError(c, "streaming error: "+err.Error()) common.LogError(c, "streaming error: "+err.Error())
} }
info.SetFirstResponseTime()
} }
lastStreamData = data lastStreamData = data
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
@@ -144,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
if shouldSendLastResp { if shouldSendLastResp {
sendStreamData(c, lastStreamData, forceFormat) sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
} }
// 计算token // 计算token

View File

@@ -13,23 +13,24 @@ import (
) )
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
TokenId int TokenId int
TokenKey string TokenKey string
UserId int UserId int
Group string Group string
TokenUnlimited bool TokenUnlimited bool
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time FirstResponseTime time.Time
setFirstResponse bool IsFirstResponse bool
ApiType int SendLastReasoningResponse bool
IsStream bool ApiType int
IsPlayground bool IsStream bool
UsePrice bool IsPlayground bool
RelayMode int UsePrice bool
UpstreamModelName string RelayMode int
OriginModelName string UpstreamModelName string
OriginModelName string
//RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
@@ -88,6 +89,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType) apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{ info := &RelayInfo{
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"), BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
@@ -139,9 +141,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
} }
func (info *RelayInfo) SetFirstResponseTime() { func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse { if info.IsFirstResponse {
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
info.setFirstResponse = true info.IsFirstResponse = false
} }
} }