refactor: Centralize stream handling and helper functions in relay package
This commit is contained in:
@@ -12,7 +12,6 @@ var relayGoPool gopool.Pool
|
|||||||
func init() {
|
func init() {
|
||||||
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
|
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
|
||||||
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
|
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
|
||||||
//check ctx.Value("stop_chan").(chan bool)
|
|
||||||
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
|
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
|
||||||
SafeSendBool(stopChan, true)
|
SafeSendBool(stopChan, true)
|
||||||
}
|
}
|
||||||
@@ -20,6 +19,6 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CtxGo(ctx context.Context, f func()) {
|
func RelayCtxGo(ctx context.Context, f func()) {
|
||||||
relayGoPool.CtxGo(ctx, f)
|
relayGoPool.CtxGo(ctx, f)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
|
||||||
switch relayMode {
|
|
||||||
default:
|
|
||||||
err = relay.TextHelper(c)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
service.WssError(c, ws, openaiErr.Error)
|
helper.WssError(c, ws, openaiErr.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
}
|
}
|
||||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||||
service.WssError(c, ws, openaiErr.Error)
|
helper.WssError(c, ws, openaiErr.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
lastResponseText := ""
|
lastResponseText := ""
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
relaymodel "one-api/dto"
|
relaymodel "one-api/dto"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
||||||
err := service.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("send final response failed: " + err.Error())
|
common.SysError("send final response failed: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package claude
|
package claude
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -9,6 +8,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
usage = &dto.Usage{}
|
usage = &dto.Usage{}
|
||||||
responseText := ""
|
responseText := ""
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(bufio.ScanLines)
|
|
||||||
service.SetEventStreamHeaders(c)
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
data := scanner.Text()
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "data:")
|
|
||||||
data = strings.TrimSpace(data)
|
|
||||||
var claudeResponse ClaudeResponse
|
var claudeResponse ClaudeResponse
|
||||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
if response == nil {
|
if response == nil {
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
responseText += claudeResponse.Completion
|
responseText += claudeResponse.Completion
|
||||||
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
usage.CompletionTokens = claudeUsage.OutputTokens
|
usage.CompletionTokens = claudeUsage.OutputTokens
|
||||||
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
||||||
} else if claudeResponse.Type == "content_block_start" {
|
} else if claudeResponse.Type == "content_block_start" {
|
||||||
|
return true
|
||||||
} else {
|
} else {
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//response.Id = responseId
|
//response.Id = responseId
|
||||||
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
|
|
||||||
err = service.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
@@ -508,13 +499,13 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
||||||
err := service.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("send final response failed: " + err.Error())
|
common.SysError("send final response failed: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
id := service.GetResponseID(c)
|
id := helper.GetResponseID(c)
|
||||||
var responseText string
|
var responseText string
|
||||||
isFirst := true
|
isFirst := true
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
response.Id = id
|
response.Id = id
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
err = service.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if isFirst {
|
if isFirst {
|
||||||
isFirst = false
|
isFirst = false
|
||||||
info.FirstResponseTime = time.Now()
|
info.FirstResponseTime = time.Now()
|
||||||
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||||
err := service.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
response.Usage = *usage
|
response.Usage = *usage
|
||||||
response.Id = service.GetResponseID(c)
|
response.Id = helper.GetResponseID(c)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
isFirst := true
|
isFirst := true
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = service.ObjectData(c, openaiResponse)
|
err = helper.ObjectData(c, openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysError(err.Error())
|
||||||
}
|
}
|
||||||
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
common.SysError("error reading stream: " + err.Error())
|
common.SysError("error reading stream: " + err.Error())
|
||||||
}
|
}
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -10,6 +9,7 @@ import (
|
|||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
|
|
||||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||||
is_stop := false
|
isStop := false
|
||||||
for _, candidate := range geminiResponse.Candidates {
|
for _, candidate := range geminiResponse.Candidates {
|
||||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||||
is_stop = true
|
isStop = true
|
||||||
candidate.FinishReason = nil
|
candidate.FinishReason = nil
|
||||||
}
|
}
|
||||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||||
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|||||||
|
|
||||||
var response dto.ChatCompletionsStreamResponse
|
var response dto.ChatCompletionsStreamResponse
|
||||||
response.Object = "chat.completion.chunk"
|
response.Object = "chat.completion.chunk"
|
||||||
response.Model = "gemini"
|
|
||||||
response.Choices = choices
|
response.Choices = choices
|
||||||
return &response, is_stop
|
return &response, isStop
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
createAt := common.GetTimestamp()
|
createAt := common.GetTimestamp()
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(bufio.ScanLines)
|
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
data = strings.TrimSpace(data)
|
|
||||||
if !strings.HasPrefix(data, "data: ") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "data: ")
|
|
||||||
data = strings.TrimSuffix(data, "\"")
|
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := json.Unmarshal([]byte(data), &geminiResponse)
|
err := json.Unmarshal([]byte(data), &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
continue
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||||
response.Id = id
|
response.Id = id
|
||||||
response.Created = createAt
|
response.Created = createAt
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||||
}
|
}
|
||||||
err = service.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
}
|
}
|
||||||
if is_stop {
|
if isStop {
|
||||||
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||||
service.ObjectData(c, response)
|
helper.ObjectData(c, response)
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
var response *dto.ChatCompletionsStreamResponse
|
var response *dto.ChatCompletionsStreamResponse
|
||||||
|
|
||||||
@@ -538,13 +527,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
||||||
|
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||||
err := service.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("send final response failed: " + err.Error())
|
common.SysError("send final response failed: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
@@ -15,16 +17,10 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||||
@@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !forceFormat && !thinkToContent {
|
if !forceFormat && !thinkToContent {
|
||||||
return service.StringData(c, data)
|
return helper.StringData(c, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||||
@@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !thinkToContent {
|
if !thinkToContent {
|
||||||
return service.ObjectData(c, lastStreamResponse)
|
return helper.ObjectData(c, lastStreamResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasThinkingContent := false
|
||||||
|
for _, choice := range lastStreamResponse.Choices {
|
||||||
|
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||||
|
hasThinkingContent = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle think to content conversion
|
// Handle think to content conversion
|
||||||
if info.IsFirstResponse {
|
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
||||||
response := lastStreamResponse.Copy()
|
if hasThinkingContent {
|
||||||
for i := range response.Choices {
|
response := lastStreamResponse.Copy()
|
||||||
response.Choices[i].Delta.SetContentString("<think>\n")
|
for i := range response.Choices {
|
||||||
response.Choices[i].Delta.SetReasoningContent("")
|
response.Choices[i].Delta.SetContentString("<think>\n")
|
||||||
|
response.Choices[i].Delta.SetReasoningContent("")
|
||||||
|
}
|
||||||
|
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
||||||
|
return helper.ObjectData(c, response)
|
||||||
|
} else {
|
||||||
|
return helper.ObjectData(c, lastStreamResponse)
|
||||||
}
|
}
|
||||||
service.ObjectData(c, response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
||||||
return service.ObjectData(c, lastStreamResponse)
|
return helper.ObjectData(c, lastStreamResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each choice
|
// Process each choice
|
||||||
for i, choice := range lastStreamResponse.Choices {
|
for i, choice := range lastStreamResponse.Choices {
|
||||||
// Handle transition from thinking to content
|
// Handle transition from thinking to content
|
||||||
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
|
if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
|
||||||
response := lastStreamResponse.Copy()
|
response := lastStreamResponse.Copy()
|
||||||
for j := range response.Choices {
|
for j := range response.Choices {
|
||||||
response.Choices[j].Delta.SetContentString("\n</think>")
|
response.Choices[j].Delta.SetContentString("\n</think>\n\n")
|
||||||
response.Choices[j].Delta.SetReasoningContent("")
|
response.Choices[j].Delta.SetReasoningContent("")
|
||||||
}
|
}
|
||||||
info.SendLastReasoningResponse = true
|
info.ThinkingContentInfo.SendLastThinkingContent = true
|
||||||
service.ObjectData(c, response)
|
helper.ObjectData(c, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert reasoning content to regular content
|
// Convert reasoning content to regular content
|
||||||
@@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return service.ObjectData(c, lastStreamResponse)
|
return helper.ObjectData(c, lastStreamResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
@@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolCount := 0
|
toolCount := 0
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(bufio.ScanLines)
|
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
|
||||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
|
||||||
// twice timeout for o1 model
|
|
||||||
streamingTimeout *= 2
|
|
||||||
}
|
|
||||||
ticker := time.NewTicker(streamingTimeout)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
stopChan := make(chan bool, 2)
|
|
||||||
defer close(stopChan)
|
|
||||||
var (
|
var (
|
||||||
lastStreamData string
|
lastStreamData string
|
||||||
mu sync.Mutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), "stop_chan", stopChan)
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
|
if lastStreamData != "" {
|
||||||
common.CtxGo(ctx, func() {
|
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||||
for scanner.Scan() {
|
if err != nil {
|
||||||
//info.SetFirstResponseTime()
|
common.LogError(c, "streaming error: "+err.Error())
|
||||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
|
||||||
data := scanner.Text()
|
|
||||||
if common.DebugEnabled {
|
|
||||||
println(data)
|
|
||||||
}
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mu.Lock()
|
|
||||||
data = data[5:]
|
|
||||||
data = strings.TrimSpace(data)
|
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
|
||||||
if lastStreamData != "" {
|
|
||||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, "streaming error: "+err.Error())
|
|
||||||
}
|
|
||||||
info.SetFirstResponseTime()
|
|
||||||
}
|
|
||||||
lastStreamData = data
|
|
||||||
streamItems = append(streamItems, data)
|
|
||||||
}
|
|
||||||
mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
common.LogError(c, "scanner error: "+err.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
lastStreamData = data
|
||||||
common.SafeSendBool(stopChan, true)
|
streamItems = append(streamItems, data)
|
||||||
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
// 超时处理逻辑
|
|
||||||
common.LogError(c, "streaming timeout")
|
|
||||||
case <-stopChan:
|
|
||||||
// 正常结束
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldSendLastResp := true
|
shouldSendLastResp := true
|
||||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||||
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
||||||
@@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||||
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||||
response.SetSystemFingerprint(systemFingerprint)
|
response.SetSystemFingerprint(systemFingerprint)
|
||||||
service.ObjectData(c, response)
|
helper.ObjectData(c, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, usage
|
return nil, usage
|
||||||
@@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
localUsage.InputTokenDetails.TextTokens += textToken
|
localUsage.InputTokenDetails.TextTokens += textToken
|
||||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||||
|
|
||||||
err = service.WssString(c, targetConn, string(message))
|
err = helper.WssString(c, targetConn, string(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||||
return
|
return
|
||||||
@@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.WssString(c, clientConn, string(message))
|
err = helper.WssString(c, clientConn, string(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
|||||||
dataChan <- string(jsonResponse)
|
dataChan <- string(jsonResponse)
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
responseText += response.Choices[0].Delta.GetContentString()
|
responseText += response.Choices[0].Delta.GetContentString()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.ObjectData(c, response)
|
err = helper.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysError(err.Error())
|
||||||
}
|
}
|
||||||
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
common.SysError("error reading stream: " + err.Error())
|
common.SysError("error reading stream: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
service.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
|||||||
@@ -12,25 +12,30 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ThinkingContentInfo struct {
|
||||||
|
IsFirstThinkingContent bool
|
||||||
|
SendLastThinkingContent bool
|
||||||
|
}
|
||||||
|
|
||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
ChannelType int
|
ChannelType int
|
||||||
ChannelId int
|
ChannelId int
|
||||||
TokenId int
|
TokenId int
|
||||||
TokenKey string
|
TokenKey string
|
||||||
UserId int
|
UserId int
|
||||||
Group string
|
Group string
|
||||||
TokenUnlimited bool
|
TokenUnlimited bool
|
||||||
StartTime time.Time
|
StartTime time.Time
|
||||||
FirstResponseTime time.Time
|
FirstResponseTime time.Time
|
||||||
IsFirstResponse bool
|
isFirstResponse bool
|
||||||
SendLastReasoningResponse bool
|
//SendLastReasoningResponse bool
|
||||||
ApiType int
|
ApiType int
|
||||||
IsStream bool
|
IsStream bool
|
||||||
IsPlayground bool
|
IsPlayground bool
|
||||||
UsePrice bool
|
UsePrice bool
|
||||||
RelayMode int
|
RelayMode int
|
||||||
UpstreamModelName string
|
UpstreamModelName string
|
||||||
OriginModelName string
|
OriginModelName string
|
||||||
//RecodeModelName string
|
//RecodeModelName string
|
||||||
RequestURLPath string
|
RequestURLPath string
|
||||||
ApiVersion string
|
ApiVersion string
|
||||||
@@ -53,6 +58,7 @@ type RelayInfo struct {
|
|||||||
UserSetting map[string]interface{}
|
UserSetting map[string]interface{}
|
||||||
UserEmail string
|
UserEmail string
|
||||||
UserQuota int
|
UserQuota int
|
||||||
|
ThinkingContentInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义支持流式选项的通道类型
|
// 定义支持流式选项的通道类型
|
||||||
@@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||||
IsFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: c.GetString("base_url"),
|
BaseUrl: c.GetString("base_url"),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
@@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||||
Organization: c.GetString("channel_organization"),
|
Organization: c.GetString("channel_organization"),
|
||||||
ChannelSetting: channelSetting,
|
ChannelSetting: channelSetting,
|
||||||
|
ThinkingContentInfo: ThinkingContentInfo{
|
||||||
|
IsFirstThinkingContent: true,
|
||||||
|
SendLastThinkingContent: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||||
info.IsPlayground = true
|
info.IsPlayground = true
|
||||||
@@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (info *RelayInfo) SetFirstResponseTime() {
|
func (info *RelayInfo) SetFirstResponseTime() {
|
||||||
if info.IsFirstResponse {
|
if info.isFirstResponse {
|
||||||
info.FirstResponseTime = time.Now()
|
info.FirstResponseTime = time.Now()
|
||||||
info.IsFirstResponse = false
|
info.isFirstResponse = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package helper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
85
relay/helper/stream_scanner.go
Normal file
85
relay/helper/stream_scanner.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||||
|
|
||||||
|
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||||
|
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
||||||
|
// twice timeout for thinking model
|
||||||
|
streamingTimeout *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
stopChan = make(chan bool, 2)
|
||||||
|
scanner = bufio.NewScanner(resp.Body)
|
||||||
|
ticker = time.NewTicker(streamingTimeout)
|
||||||
|
)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
close(stopChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, "stop_chan", stopChan)
|
||||||
|
common.RelayCtxGo(ctx, func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
ticker.Reset(streamingTimeout)
|
||||||
|
data := scanner.Text()
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) < 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
data = strings.TrimLeft(data, " ")
|
||||||
|
data = strings.TrimSuffix(data, "\"")
|
||||||
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
|
info.SetFirstResponseTime()
|
||||||
|
success := dataHandler(data)
|
||||||
|
if !success {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
common.LogError(c, "scanner error: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
common.SafeSendBool(stopChan, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
// 超时处理逻辑
|
||||||
|
common.LogError(c, "streaming timeout")
|
||||||
|
case <-stopChan:
|
||||||
|
// 正常结束
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -326,7 +326,6 @@ func GetModelRatio(name string) (float64, bool) {
|
|||||||
}
|
}
|
||||||
ratio, ok := modelRatioMap[name]
|
ratio, ok := modelRatioMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
common.SysError("model ratio not found: " + name)
|
|
||||||
return 37.5, operation_setting.SelfUseModeEnabled
|
return 37.5, operation_setting.SelfUseModeEnabled
|
||||||
}
|
}
|
||||||
return ratio, true
|
return ratio, true
|
||||||
|
|||||||
Reference in New Issue
Block a user