diff --git a/common/gopool.go b/common/gopool.go index 91fc62cd..bf5df311 100644 --- a/common/gopool.go +++ b/common/gopool.go @@ -12,7 +12,6 @@ var relayGoPool gopool.Pool func init() { relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig()) relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) { - //check ctx.Value("stop_chan").(chan bool) if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok { SafeSendBool(stopChan, true) } @@ -20,6 +19,6 @@ func init() { }) } -func CtxGo(ctx context.Context, f func()) { +func RelayCtxGo(ctx context.Context, f func()) { relayGoPool.CtxGo(ctx, f) } diff --git a/controller/relay.go b/controller/relay.go index e27ebb80..460599b5 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -16,6 +16,7 @@ import ( "one-api/relay" "one-api/relay/constant" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode return err } -func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode { - var err *dto.OpenAIErrorWithStatusCode - switch relayMode { - default: - err = relay.TextHelper(c) - } - return err -} - func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) @@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) { if err != nil { openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) - service.WssError(c, ws, openaiErr.Error) + helper.WssError(c, ws, openaiErr.Error) return } @@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) { openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) - service.WssError(c, ws, openaiErr.Error) + helper.WssError(c, ws, openaiErr.Error) } } diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index db4df0a9..3fe893b3 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) lastResponseText := "" c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 1b0882b3..976f97ce 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -12,6 +12,7 @@ import ( relaymodel "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } }) if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) + helper.Done(c) if resp != nil { err = resp.Body.Close() if err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index d88f5212..62b06413 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index e32ee817..5c47cbc4 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,7 +1,6 @@ package claude import ( - "bufio" "encoding/json" "fmt" "io" @@ -9,6 +8,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" "strings" @@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage = &dto.Usage{} responseText := "" createdTime := common.GetTimestamp() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - for scanner.Scan() { - data := scanner.Text() - info.SetFirstResponseTime() - if len(data) < 6 || !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - data = strings.TrimSpace(data) + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var claudeResponse ClaudeResponse err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - continue + return true } response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if response == nil { - continue + return true } if requestMode == RequestModeCompletion { responseText += claudeResponse.Completion @@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage.CompletionTokens = claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens } else if claudeResponse.Type == "content_block_start" { - + return true } else { - continue + return true } } //response.Id = responseId @@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. response.Created = createdTime response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.LogError(c, "send_stream_response_failed: "+err.Error()) } - } + return true + }) if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) @@ -508,13 +499,13 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) + helper.Done(c) resp.Body.Close() return nil, usage } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index d21e524d..a487429c 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - id := service.GetResponseID(c) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) var responseText string isFirst := true @@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } response.Id = id response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if isFirst { isFirst = false info.FirstResponseTime = time.Now() @@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { @@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) } usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) response.Usage = *usage - response.Id = service.GetResponseID(c) + response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 132039b3..17b58dbc 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) isFirst := true c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 5df34d35..3e62d41c 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -10,6 +10,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) for scanner.Scan() { data := scanner.Text() @@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re responseText += openaiResponse.Choices[0].Delta.GetContentString() } } - err = service.ObjectData(c, openaiResponse) + err = helper.ObjectData(c, openaiResponse) if err != nil { common.SysError(err.Error()) } @@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re if err := scanner.Err(); err != nil { common.SysError("error reading stream: " + err.Error()) } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index d5103124..f6e5df1e 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1,7 +1,6 @@ package gemini import ( - "bufio" "encoding/json" "fmt" "io" @@ -10,6 +9,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" "strings" @@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) - is_stop := false + isStop := false for _, candidate := range geminiResponse.Candidates { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { - is_stop = true + isStop = true candidate.FinishReason = nil } choice := dto.ChatCompletionsStreamResponseChoice{ @@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" - response.Model = "gemini" response.Choices = choices - return &response, is_stop + return &response, isStop } func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() var usage = &dto.Usage{} - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - for scanner.Scan() { - data := scanner.Text() - info.SetFirstResponseTime() - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\"") + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse err := json.Unmarshal([]byte(data), &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) - continue + return false } - response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse) + response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse) response.Id = id response.Created = createAt response.Model = info.UpstreamModelName @@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount } - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.LogError(c, err.Error()) } - if is_stop { - response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) - service.ObjectData(c, response) + if isStop { + response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) + helper.ObjectData(c, response) } - } + return true + }) var response *dto.ChatCompletionsStreamResponse @@ -538,13 +527,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens if info.ShouldIncludeUsage { - response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) + helper.Done(c) resp.Body.Close() return nil, usage } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 696fa3a0..fd5e3d74 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,11 +1,13 @@ package openai import ( - "bufio" "bytes" - "context" "encoding/json" "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/pkg/errors" "io" "math" "mime/multipart" @@ -15,16 +17,10 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "os" "strings" - "sync" - "time" - - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "github.com/pkg/errors" ) func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { @@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } if !forceFormat && !thinkToContent { - return service.StringData(c, data) + return helper.StringData(c, data) } var lastStreamResponse dto.ChatCompletionsStreamResponse @@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } if !thinkToContent { - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) + } + + hasThinkingContent := false + for _, choice := range lastStreamResponse.Choices { + if len(choice.Delta.GetReasoningContent()) > 0 { + hasThinkingContent = true + break + } } // Handle think to content conversion - if info.IsFirstResponse { - response := lastStreamResponse.Copy() - for i := range response.Choices { - response.Choices[i].Delta.SetContentString("\n") - response.Choices[i].Delta.SetReasoningContent("") + if info.ThinkingContentInfo.IsFirstThinkingContent { + if hasThinkingContent { + response := lastStreamResponse.Copy() + for i := range response.Choices { + response.Choices[i].Delta.SetContentString("\n") + response.Choices[i].Delta.SetReasoningContent("") + } + info.ThinkingContentInfo.IsFirstThinkingContent = false + return helper.ObjectData(c, response) + } else { + return helper.ObjectData(c, lastStreamResponse) } - service.ObjectData(c, response) } if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 { - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) } // Process each choice for i, choice := range lastStreamResponse.Choices { // Handle transition from thinking to content - if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse { + if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent { response := lastStreamResponse.Copy() for j := range response.Choices { - response.Choices[j].Delta.SetContentString("\n") + response.Choices[j].Delta.SetContentString("\n\n\n") response.Choices[j].Delta.SetReasoningContent("") } - info.SendLastReasoningResponse = true - service.ObjectData(c, response) + info.ThinkingContentInfo.SendLastThinkingContent = true + helper.ObjectData(c, response) } // Convert reasoning content to regular content @@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } } - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) } func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } toolCount := 0 - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second - if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") { - // twice timeout for o1 model - streamingTimeout *= 2 - } - ticker := time.NewTicker(streamingTimeout) - defer ticker.Stop() - - stopChan := make(chan bool, 2) - defer close(stopChan) var ( lastStreamData string - mu sync.Mutex ) - ctx := context.WithValue(context.Background(), "stop_chan", stopChan) - - common.CtxGo(ctx, func() { - for scanner.Scan() { - //info.SetFirstResponseTime() - ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) - data := scanner.Text() - if common.DebugEnabled { - println(data) - } - if len(data) < 6 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" && data[:6] != "[DONE]" { - continue - } - mu.Lock() - data = data[5:] - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "[DONE]") { - if lastStreamData != "" { - err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) - if err != nil { - common.LogError(c, "streaming error: "+err.Error()) - } - info.SetFirstResponseTime() - } - lastStreamData = data - streamItems = append(streamItems, data) - } - mu.Unlock() - } - - if err := scanner.Err(); err != nil { - if err != io.EOF { - common.LogError(c, "scanner error: "+err.Error()) + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + if lastStreamData != "" { + err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + if err != nil { + common.LogError(c, "streaming error: "+err.Error()) } } - - common.SafeSendBool(stopChan, true) + lastStreamData = data + streamItems = append(streamItems, data) + return true }) - select { - case <-ticker.C: - // 超时处理逻辑 - common.LogError(c, "streaming timeout") - case <-stopChan: - // 正常结束 - } - shouldSendLastResp := true var lastStreamResponse dto.ChatCompletionsStreamResponse err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) @@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } if info.ShouldIncludeUsage && !containStreamUsage { - response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) + response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) - service.ObjectData(c, response) + helper.ObjectData(c, response) } - service.Done(c) + helper.Done(c) resp.Body.Close() return nil, usage @@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.AudioTokens += audioToken - err = service.WssString(c, targetConn, string(message)) + err = helper.WssString(c, targetConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to target: %v", err) return @@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.OutputTokenDetails.AudioTokens += audioToken } - err = service.WssString(c, clientConn, string(message)) + err = helper.WssString(c, clientConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to client: %v", err) return diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 02a3e382..c8e337de 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" ) @@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit dataChan <- string(jsonResponse) stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index dd3ac93f..5630650f 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -14,6 +14,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strconv" "strings" @@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) for scanner.Scan() { data := scanner.Text() @@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError responseText += response.Choices[0].Delta.GetContentString() } - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.SysError(err.Error()) } @@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError common.SysError("error reading stream: " + err.Error()) } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 067ff6e4..15d33510 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -14,6 +14,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a if err != nil { return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) var usage dto.Usage c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 6bdd1c2a..b0cac858 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 97d82c71..faffec6f 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 022ab628..c1d3f4a4 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -12,25 +12,30 @@ import ( "github.com/gorilla/websocket" ) +type ThinkingContentInfo struct { + IsFirstThinkingContent bool + SendLastThinkingContent bool +} + type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - IsFirstResponse bool - SendLastReasoningResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + isFirstResponse bool + //SendLastReasoningResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string @@ -53,6 +58,7 @@ type RelayInfo struct { UserSetting map[string]interface{} UserEmail string UserQuota int + ThinkingContentInfo } // 定义支持流式选项的通道类型 @@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { UserQuota: c.GetInt(constant.ContextKeyUserQuota), UserSetting: c.GetStringMap(constant.ContextKeyUserSetting), UserEmail: c.GetString(constant.ContextKeyUserEmail), - IsFirstResponse: true, + isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), BaseUrl: c.GetString("base_url"), RequestURLPath: c.Request.URL.String(), @@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), ChannelSetting: channelSetting, + ThinkingContentInfo: ThinkingContentInfo{ + IsFirstThinkingContent: true, + SendLastThinkingContent: false, + }, } if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true @@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - if info.IsFirstResponse { + if info.isFirstResponse { info.FirstResponseTime = time.Now() - info.IsFirstResponse = false + info.isFirstResponse = false } } diff --git a/service/relay.go b/relay/helper/common.go similarity index 99% rename from service/relay.go rename to relay/helper/common.go index 6ffed1e2..2a72d30a 100644 --- a/service/relay.go +++ b/relay/helper/common.go @@ -1,4 +1,4 @@ -package service +package helper import ( "encoding/json" diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go new file mode 100644 index 00000000..07462aa7 --- /dev/null +++ b/relay/helper/stream_scanner.go @@ -0,0 +1,85 @@ +package helper + +import ( + "bufio" + "context" + "io" + "net/http" + "one-api/common" + "one-api/constant" + relaycommon "one-api/relay/common" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { + + streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second + if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") { + // twice timeout for thinking model + streamingTimeout *= 2 + } + + var ( + stopChan = make(chan bool, 2) + scanner = bufio.NewScanner(resp.Body) + ticker = time.NewTicker(streamingTimeout) + ) + + defer func() { + ticker.Stop() + close(stopChan) + }() + + scanner.Split(bufio.ScanLines) + SetEventStreamHeaders(c) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctx = context.WithValue(ctx, "stop_chan", stopChan) + common.RelayCtxGo(ctx, func() { + for scanner.Scan() { + ticker.Reset(streamingTimeout) + data := scanner.Text() + if common.DebugEnabled { + println(data) + } + + if len(data) < 6 { + continue + } + if data[:5] != "data:" && data[:6] != "[DONE]" { + continue + } + data = data[5:] + data = strings.TrimLeft(data, " ") + data = strings.TrimSuffix(data, "\"") + if !strings.HasPrefix(data, "[DONE]") { + info.SetFirstResponseTime() + success := dataHandler(data) + if !success { + break + } + } + } + + if err := scanner.Err(); err != nil { + if err != io.EOF { + common.LogError(c, "scanner error: "+err.Error()) + } + } + + common.SafeSendBool(stopChan, true) + }) + + select { + case <-ticker.C: + // 超时处理逻辑 + common.LogError(c, "streaming timeout") + case <-stopChan: + // 正常结束 + } +} diff --git a/setting/model-ratio.go b/setting/model-ratio.go index 711dbeec..54b214f9 100644 --- a/setting/model-ratio.go +++ b/setting/model-ratio.go @@ -326,7 +326,6 @@ func GetModelRatio(name string) (float64, bool) { } ratio, ok := modelRatioMap[name] if !ok { - common.SysError("model ratio not found: " + name) return 37.5, operation_setting.SelfUseModeEnabled } return ratio, true