refactor: Centralize stream handling and helper functions in relay package

This commit is contained in:
1808837298@qq.com
2025-03-05 19:47:41 +08:00
parent 37bb34b4b0
commit 37a83ecc33
20 changed files with 228 additions and 195 deletions

View File

@@ -12,7 +12,6 @@ var relayGoPool gopool.Pool
func init() { func init() {
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig()) relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) { relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
//check ctx.Value("stop_chan").(chan bool)
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok { if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
SafeSendBool(stopChan, true) SafeSendBool(stopChan, true)
} }
@@ -20,6 +19,6 @@ func init() {
}) })
} }
func CtxGo(ctx context.Context, f func()) { func RelayCtxGo(ctx context.Context, f func()) {
relayGoPool.CtxGo(ctx, f) relayGoPool.CtxGo(ctx, f)
} }

View File

@@ -16,6 +16,7 @@ import (
"one-api/relay" "one-api/relay"
"one-api/relay/constant" "one-api/relay/constant"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
) )
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err return err
} }
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
default:
err = relay.TextHelper(c)
}
return err
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
if err != nil { if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error) helper.WssError(c, ws, openaiErr.Error)
return return
} }
@@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
} }
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error) helper.WssError(c, ws, openaiErr.Error)
} }
} }

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
) )
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
} }
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
lastResponseText := "" lastResponseText := ""
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {

View File

@@ -12,6 +12,7 @@ import (
relaymodel "one-api/dto" relaymodel "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"time" "time"
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
}) })
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
service.Done(c) helper.Done(c)
if resp != nil { if resp != nil {
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"sync" "sync"
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
} }
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:

View File

@@ -1,7 +1,6 @@
package claude package claude
import ( import (
"bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -9,6 +8,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting/model_setting" "one-api/setting/model_setting"
"strings" "strings"
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage = &dto.Usage{} usage = &dto.Usage{}
responseText := "" responseText := ""
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
var claudeResponse ClaudeResponse var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse) err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
continue return true
} }
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil { if response == nil {
continue return true
} }
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion responseText += claudeResponse.Completion
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage.CompletionTokens = claudeUsage.OutputTokens usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
return true
} else { } else {
continue return true
} }
} }
//response.Id = responseId //response.Id = responseId
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime response.Created = createdTime
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
err = service.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error()) common.LogError(c, "send_stream_response_failed: "+err.Error())
} }
} return true
})
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
@@ -508,13 +499,13 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
} }
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
service.Done(c) helper.Done(c)
resp.Body.Close() resp.Body.Close()
return nil, usage return nil, usage
} }

View File

@@ -9,6 +9,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"time" "time"
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
id := service.GetResponseID(c) id := helper.GetResponseID(c)
var responseText string var responseText string
isFirst := true isFirst := true
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
} }
response.Id = id response.Id = id
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
err = service.ObjectData(c, response) err = helper.ObjectData(c, response)
if isFirst { if isFirst {
isFirst = false isFirst = false
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
} }
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
} }
} }
service.Done(c) helper.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
} }
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage response.Usage = *usage
response.Id = service.GetResponseID(c) response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -10,6 +10,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"time" "time"
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
isFirst := true isFirst := true
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {

View File

@@ -10,6 +10,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
) )
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
responseText += openaiResponse.Choices[0].Delta.GetContentString() responseText += openaiResponse.Choices[0].Delta.GetContentString()
} }
} }
err = service.ObjectData(c, openaiResponse) err = helper.ObjectData(c, openaiResponse)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysError(err.Error())
} }
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error()) common.SysError("error reading stream: " + err.Error())
} }
service.Done(c) helper.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,7 +1,6 @@
package gemini package gemini
import ( import (
"bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -10,6 +9,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting/model_setting" "one-api/setting/model_setting"
"strings" "strings"
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false isStop := false
for _, candidate := range geminiResponse.Candidates { for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true isStop = true
candidate.FinishReason = nil candidate.FinishReason = nil
} }
choice := dto.ChatCompletionsStreamResponseChoice{ choice := dto.ChatCompletionsStreamResponseChoice{
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = choices response.Choices = choices
return &response, is_stop return &response, isStop
} }
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
var usage = &dto.Usage{} var usage = &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c) helper.StreamScannerHandler(c, resp, info, func(data string) bool {
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse) err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
continue return false
} }
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse) response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id response.Id = id
response.Created = createAt response.Created = createAt
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
} }
err = service.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
} }
if is_stop { if isStop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response) helper.ObjectData(c, response)
} }
} return true
})
var response *dto.ChatCompletionsStreamResponse var response *dto.ChatCompletionsStreamResponse
@@ -538,13 +527,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
service.Done(c) helper.Done(c)
resp.Body.Close() resp.Body.Close()
return nil, usage return nil, usage
} }

View File

@@ -1,11 +1,13 @@
package openai package openai
import ( import (
"bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io" "io"
"math" "math"
"mime/multipart" "mime/multipart"
@@ -15,16 +17,10 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"os" "os"
"strings" "strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
) )
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
} }
if !forceFormat && !thinkToContent { if !forceFormat && !thinkToContent {
return service.StringData(c, data) return helper.StringData(c, data)
} }
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
@@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
} }
if !thinkToContent { if !thinkToContent {
return service.ObjectData(c, lastStreamResponse) return helper.ObjectData(c, lastStreamResponse)
}
hasThinkingContent := false
for _, choice := range lastStreamResponse.Choices {
if len(choice.Delta.GetReasoningContent()) > 0 {
hasThinkingContent = true
break
}
} }
// Handle think to content conversion // Handle think to content conversion
if info.IsFirstResponse { if info.ThinkingContentInfo.IsFirstThinkingContent {
response := lastStreamResponse.Copy() if hasThinkingContent {
for i := range response.Choices { response := lastStreamResponse.Copy()
response.Choices[i].Delta.SetContentString("<think>\n") for i := range response.Choices {
response.Choices[i].Delta.SetReasoningContent("") response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
info.ThinkingContentInfo.IsFirstThinkingContent = false
return helper.ObjectData(c, response)
} else {
return helper.ObjectData(c, lastStreamResponse)
} }
service.ObjectData(c, response)
} }
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 { if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse) return helper.ObjectData(c, lastStreamResponse)
} }
// Process each choice // Process each choice
for i, choice := range lastStreamResponse.Choices { for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content // Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse { if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
response := lastStreamResponse.Copy() response := lastStreamResponse.Copy()
for j := range response.Choices { for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>") response.Choices[j].Delta.SetContentString("\n</think>\n\n")
response.Choices[j].Delta.SetReasoningContent("") response.Choices[j].Delta.SetReasoningContent("")
} }
info.SendLastReasoningResponse = true info.ThinkingContentInfo.SendLastThinkingContent = true
service.ObjectData(c, response) helper.ObjectData(c, response)
} }
// Convert reasoning content to regular content // Convert reasoning content to regular content
@@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
} }
} }
return service.ObjectData(c, lastStreamResponse) return helper.ObjectData(c, lastStreamResponse)
} }
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
toolCount := 0 toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for o1 model
streamingTimeout *= 2
}
ticker := time.NewTicker(streamingTimeout)
defer ticker.Stop()
stopChan := make(chan bool, 2)
defer close(stopChan)
var ( var (
lastStreamData string lastStreamData string
mu sync.Mutex
) )
ctx := context.WithValue(context.Background(), "stop_chan", stopChan) helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
common.CtxGo(ctx, func() { err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
for scanner.Scan() { if err != nil {
//info.SetFirstResponseTime() common.LogError(c, "streaming error: "+err.Error())
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
}
mu.Unlock()
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
} }
} }
lastStreamData = data
common.SafeSendBool(stopChan, true) streamItems = append(streamItems, data)
return true
}) })
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
shouldSendLastResp := true shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
@@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
if info.ShouldIncludeUsage && !containStreamUsage { if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint) response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response) helper.ObjectData(c, response)
} }
service.Done(c) helper.Done(c)
resp.Body.Close() resp.Body.Close()
return nil, usage return nil, usage
@@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message)) err = helper.WssString(c, targetConn, string(message))
if err != nil { if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err) errChan <- fmt.Errorf("error writing to target: %v", err)
return return
@@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.OutputTokenDetails.AudioTokens += audioToken localUsage.OutputTokenDetails.AudioTokens += audioToken
} }
err = service.WssString(c, clientConn, string(message)) err = helper.WssString(c, clientConn, string(message))
if err != nil { if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err) errChan <- fmt.Errorf("error writing to client: %v", err)
return return

View File

@@ -9,6 +9,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
) )
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
dataChan <- string(jsonResponse) dataChan <- string(jsonResponse)
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:

View File

@@ -14,6 +14,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings" "strings"
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
responseText += response.Choices[0].Delta.GetContentString() responseText += response.Choices[0].Delta.GetContentString()
} }
err = service.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysError(err.Error())
} }
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
common.SysError("error reading stream: " + err.Error()) common.SysError("error reading stream: " + err.Error())
} }
service.Done(c) helper.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"time" "time"
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
} }
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
var usage dto.Usage var usage dto.Usage
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {

View File

@@ -10,6 +10,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"sync" "sync"
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
} }
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:

View File

@@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"sync" "sync"
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
} }
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:

View File

@@ -12,25 +12,30 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type ThinkingContentInfo struct {
IsFirstThinkingContent bool
SendLastThinkingContent bool
}
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
TokenId int TokenId int
TokenKey string TokenKey string
UserId int UserId int
Group string Group string
TokenUnlimited bool TokenUnlimited bool
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time FirstResponseTime time.Time
IsFirstResponse bool isFirstResponse bool
SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool IsPlayground bool
UsePrice bool UsePrice bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
//RecodeModelName string //RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
@@ -53,6 +58,7 @@ type RelayInfo struct {
UserSetting map[string]interface{} UserSetting map[string]interface{}
UserEmail string UserEmail string
UserQuota int UserQuota int
ThinkingContentInfo
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
UserQuota: c.GetInt(constant.ContextKeyUserQuota), UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting), UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail), UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true, isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"), BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
@@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting, ChannelSetting: channelSetting,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") { if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true info.IsPlayground = true
@@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
} }
func (info *RelayInfo) SetFirstResponseTime() { func (info *RelayInfo) SetFirstResponseTime() {
if info.IsFirstResponse { if info.isFirstResponse {
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
info.IsFirstResponse = false info.isFirstResponse = false
} }
} }

View File

@@ -1,4 +1,4 @@
package service package helper
import ( import (
"encoding/json" "encoding/json"

View File

@@ -0,0 +1,85 @@
package helper
import (
"bufio"
"context"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for thinking model
streamingTimeout *= 2
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
)
defer func() {
ticker.Stop()
close(stopChan)
}()
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 {
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
data = data[5:]
data = strings.TrimLeft(data, " ")
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
success := dataHandler(data)
if !success {
break
}
}
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
}
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
}

View File

@@ -326,7 +326,6 @@ func GetModelRatio(name string) (float64, bool) {
} }
ratio, ok := modelRatioMap[name] ratio, ok := modelRatioMap[name]
if !ok { if !ok {
common.SysError("model ratio not found: " + name)
return 37.5, operation_setting.SelfUseModeEnabled return 37.5, operation_setting.SelfUseModeEnabled
} }
return ratio, true return ratio, true