feat: Add thinking-to-content conversion for stream responses
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
||||
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
|
||||
)
|
||||
|
||||
@@ -10,6 +10,10 @@
|
||||
- 用于配置网络代理
|
||||
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
|
||||
|
||||
3. thinking_to_content
|
||||
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
|
||||
- 类型为布尔值,设置为 true 时启用思考内容转换
|
||||
|
||||
--------------------------------------------------------------
|
||||
|
||||
## JSON 格式示例
|
||||
@@ -19,6 +23,7 @@
|
||||
```json
|
||||
{
|
||||
"force_format": true,
|
||||
"thinking_to_content": true,
|
||||
"proxy": "socks5://xxxxxxx"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -86,6 +86,10 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
|
||||
return *c.ReasoningContent
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||
c.ReasoningContent = &s
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
// Index is not nil only in chat completion chunk object
|
||||
Index *int `json:"index,omitempty"`
|
||||
@@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||
copy(choices, c.Choices)
|
||||
return &ChatCompletionsStreamResponse{
|
||||
Id: c.Id,
|
||||
Object: c.Object,
|
||||
Created: c.Created,
|
||||
Model: c.Model,
|
||||
SystemFingerprint: c.SystemFingerprint,
|
||||
Choices: choices,
|
||||
Usage: c.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
|
||||
if c.SystemFingerprint == nil {
|
||||
return ""
|
||||
|
||||
@@ -5,10 +5,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
@@ -23,21 +19,66 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
|
||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
if data == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if forceFormat {
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if !forceFormat && !thinkToContent {
|
||||
return service.StringData(c, data)
|
||||
}
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !thinkToContent {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
return service.StringData(c, data)
|
||||
|
||||
// Handle think to content conversion
|
||||
if info.IsFirstResponse {
|
||||
response := lastStreamResponse.Copy()
|
||||
for i := range response.Choices {
|
||||
response.Choices[i].Delta.SetContentString("<think>\n")
|
||||
response.Choices[i].Delta.SetReasoningContent("")
|
||||
}
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
|
||||
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
// Process each choice
|
||||
for i, choice := range lastStreamResponse.Choices {
|
||||
// Handle transition from thinking to content
|
||||
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
|
||||
response := lastStreamResponse.Copy()
|
||||
for j := range response.Choices {
|
||||
response.Choices[j].Delta.SetContentString("\n</think>")
|
||||
response.Choices[j].Delta.SetReasoningContent("")
|
||||
}
|
||||
info.SendLastReasoningResponse = true
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
|
||||
// Convert reasoning content to regular content
|
||||
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
|
||||
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
|
||||
}
|
||||
}
|
||||
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var forceFormat bool
|
||||
var thinkToContent bool
|
||||
|
||||
if info.ChannelType == common.ChannelTypeCustom {
|
||||
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
}
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
}
|
||||
|
||||
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
|
||||
thinkToContent = think2Content
|
||||
}
|
||||
|
||||
toolCount := 0
|
||||
@@ -84,7 +128,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
)
|
||||
gopool.Go(func() {
|
||||
for scanner.Scan() {
|
||||
info.SetFirstResponseTime()
|
||||
//info.SetFirstResponseTime()
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
@@ -101,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
if lastStreamData != "" {
|
||||
err := sendStreamData(c, lastStreamData, forceFormat)
|
||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.LogError(c, "streaming error: "+err.Error())
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
@@ -144,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
}
|
||||
if shouldSendLastResp {
|
||||
sendStreamData(c, lastStreamData, forceFormat)
|
||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
}
|
||||
|
||||
// 计算token
|
||||
|
||||
@@ -13,23 +13,24 @@ import (
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
setFirstResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
IsFirstResponse bool
|
||||
SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
@@ -88,6 +89,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
IsFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
@@ -139,9 +141,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
if !info.setFirstResponse {
|
||||
if info.IsFirstResponse {
|
||||
info.FirstResponseTime = time.Now()
|
||||
info.setFirstResponse = true
|
||||
info.IsFirstResponse = false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user