diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 6c8d9ebe..70ea51bf 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) - // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() + // 检查请求拦截(预热请求、SUGGESTION MODE等) + if account.IsInterceptWarmupEnabled() { + interceptType := detectInterceptType(body) + if interceptType != InterceptTypeNone { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockInterceptStream(c, reqModel, interceptType) + } else { + sendMockInterceptResponse(c, reqModel, interceptType) + } + return } - if reqStream { - sendMockWarmupStream(c, reqModel) - } else { - sendMockWarmupResponse(c, reqModel) - } - return } // 3. 获取账号并发槽位 @@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) - // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() + // 检查请求拦截(预热请求、SUGGESTION MODE等) + if account.IsInterceptWarmupEnabled() { + interceptType := detectInterceptType(body) + if interceptType != InterceptTypeNone { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockInterceptStream(c, reqModel, interceptType) + } else { + sendMockInterceptResponse(c, reqModel, interceptType) + } + return } - if reqStream { - sendMockWarmupStream(c, reqModel) - } else { - sendMockWarmupResponse(c, reqModel) - } - return } // 3. 获取账号并发槽位 @@ -765,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } } -// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等) -func isWarmupRequest(body []byte) bool { - // 快速检查:如果body不包含关键字,直接返回false +// InterceptType 表示请求拦截类型 +type InterceptType int + +const ( + InterceptTypeNone InterceptType = iota + InterceptTypeWarmup // 预热请求(返回 "New Conversation") + InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串) +) + +// detectInterceptType 检测请求是否需要拦截,返回拦截类型 +func detectInterceptType(body []byte) InterceptType { + // 快速检查:如果不包含任何关键字,直接返回 bodyStr := string(body) - if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") { - return false + hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:") + hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup") + + if !hasSuggestionMode && !hasWarmupKeyword { + return InterceptTypeNone } - // 解析完整请求 + // 解析请求(只解析一次) var req struct { Messages []struct { + Role string `json:"role"` Content []struct { Type string `json:"type"` Text string `json:"text"` @@ -786,43 +805,71 @@ func isWarmupRequest(body []byte) bool { } `json:"system"` } if err := json.Unmarshal(body, &req); err != nil { - return false + return InterceptTypeNone } - // 检查 messages 中的标题提示模式 - for _, msg := range req.Messages { - for _, content := range msg.Content { - if content.Type == "text" { - if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") || - content.Text == "Warmup" { - return true + // 检查 SUGGESTION MODE(最后一条 user 消息) + if hasSuggestionMode && len(req.Messages) > 0 { + lastMsg := req.Messages[len(req.Messages)-1] + if lastMsg.Role == "user" && len(lastMsg.Content) > 0 && + lastMsg.Content[0].Type == "text" && + strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") { + return InterceptTypeSuggestionMode + } + } + + // 检查 Warmup 请求 + if hasWarmupKeyword { + // 检查 messages 中的标题提示模式 + for _, msg := range req.Messages { + for _, content := range msg.Content { + if content.Type == "text" { + if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") || + content.Text == "Warmup" { + return InterceptTypeWarmup + } } } } + // 检查 system 中的标题提取模式 + for _, sys := range req.System { + if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") { + return InterceptTypeWarmup + } + } } - // 检查 system 中的标题提取模式 - for _, system := range req.System { - if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") { - return true - } - } - - return false + return InterceptTypeNone } -// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截) -func sendMockWarmupStream(c *gin.Context, model string) { +// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截) +func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") + // 根据拦截类型决定响应内容 + var msgID string + var outputTokens int + var textDeltas []string + + switch interceptType { + case InterceptTypeSuggestionMode: + msgID = "msg_mock_suggestion" + outputTokens = 1 + textDeltas = []string{""} // 空内容 + default: // InterceptTypeWarmup + msgID = "msg_mock_warmup" + outputTokens = 2 + textDeltas = []string{"New", " Conversation"} + } + // Build message_start event with proper JSON marshaling messageStart := map[string]any{ "type": "message_start", "message": map[string]any{ - "id": "msg_mock_warmup", + "id": msgID, "type": "message", "role": "assistant", "model": model, @@ -837,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) { } messageStartJSON, _ := json.Marshal(messageStart) + // Build events events := []string{ `event: message_start` + "\n" + `data: ` + string(messageStartJSON), `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, - `event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`, - `event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`, - `event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`, - `event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`, - `event: message_stop` + "\n" + `data: {"type":"message_stop"}`, } + // Add text deltas + for _, text := range textDeltas { + delta := map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]string{ + "type": "text_delta", + "text": text, + }, + } + deltaJSON, _ := json.Marshal(delta) + events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) + } + + // Add final events + messageDelta := map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]int{ + "input_tokens": 10, + "output_tokens": outputTokens, + }, + } + messageDeltaJSON, _ := json.Marshal(messageDelta) + + events = append(events, + `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, + `event: message_delta`+"\n"+`data: `+string(messageDeltaJSON), + `event: message_stop`+"\n"+`data: {"type":"message_stop"}`, + ) + for _, event := range events { _, _ = c.Writer.WriteString(event + "\n\n") c.Writer.Flush() @@ -854,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) { } } -// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截) -func sendMockWarmupResponse(c *gin.Context, model string) { +// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截) +func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) { + var msgID, text string + var outputTokens int + + switch interceptType { + case InterceptTypeSuggestionMode: + msgID = "msg_mock_suggestion" + text = "" + outputTokens = 1 + default: // InterceptTypeWarmup + msgID = "msg_mock_warmup" + text = "New Conversation" + outputTokens = 2 + } + c.JSON(http.StatusOK, gin.H{ - "id": "msg_mock_warmup", + "id": msgID, "type": "message", "role": "assistant", "model": model, - "content": []gin.H{{"type": "text", "text": "New Conversation"}}, + "content": []gin.H{{"type": "text", "text": text}}, "stop_reason": "end_turn", "usage": gin.H{ "input_tokens": 10, - "output_tokens": 2, + "output_tokens": outputTokens, }, }) }