fix: stabilize thinking streams, multimodal parsing, and token accounting (#20)
* fix: stabilize multimodal image compatibility across OpenCode flows Advertise vision-capable metadata in /v1/models and make model matching deterministic so OpenCode does not downgrade image support or route 4.6 models incorrectly. Expand request translation to accept OpenCode/OpenAI attachment shapes, sanitize [Image N] placeholders safely, keep image-only follow-up turns non-empty, and improve token accounting so base64 image bytes no longer inflate prompt token usage and trigger premature compaction. * fix: deduplicate thinking streams and trim injected prompt noise * fix: align /v1/messages thinking blocks and message_start usage * fix: reduce repetitive thinking across tool turns Select a single reasoning stream source, prevent chunk replay, and preserve structured tool-loop context so the model keeps continuity instead of re-planning each turn. * fix: unify token counting on existing API endpoints Compute usage deterministically on /v1/messages and /v1/chat/completions even when upstream omits tokenUsage. - remove roo-only token path and keep behavior on existing endpoints - add proxy/token_estimator.go with shared Claude/OpenAI estimators (input/system/messages/tools + output/thinking/tool calls) - wire stream/non-stream handlers to use estimator-derived input/output usage - update /v1/messages/count_tokens to reuse the same estimator - keep robust upstream usage parsing/normalization in proxy/kiro.go while dropping parser-level estimate fallback Why: direct upstream tests show metering/context events frequently arrive without tokenUsage in this environment; this made usage zero or inconsistent. Local deterministic accounting keeps reported usage stable and explicit.
This commit is contained in:
568
proxy/handler.go
568
proxy/handler.go
@@ -35,6 +35,32 @@ type Handler struct {
|
||||
modelsCacheTime int64
|
||||
}
|
||||
|
||||
// thinkingStreamSource records which upstream channel a streaming handler has
// committed to for thinking output, so that only one channel is forwarded.
type thinkingStreamSource int

// Values of thinkingStreamSource. The zero value means no channel has been
// selected yet.
const (
	// thinkingSourceUnknown: no thinking output has been accepted so far.
	thinkingSourceUnknown thinkingStreamSource = iota
	// thinkingSourceReasoningEvent: thinking arrives as dedicated reasoning events.
	thinkingSourceReasoningEvent
	// thinkingSourceTagBlock: thinking arrives inline inside <thinking> tags.
	thinkingSourceTagBlock
)
|
||||
|
||||
func allowReasoningSource(source *thinkingStreamSource) bool {
|
||||
if *source == thinkingSourceTagBlock {
|
||||
return false
|
||||
}
|
||||
*source = thinkingSourceReasoningEvent
|
||||
return true
|
||||
}
|
||||
|
||||
func allowTagSource(source *thinkingStreamSource) bool {
|
||||
if *source == thinkingSourceReasoningEvent {
|
||||
return false
|
||||
}
|
||||
if *source == thinkingSourceUnknown {
|
||||
*source = thinkingSourceTagBlock
|
||||
}
|
||||
return *source == thinkingSourceTagBlock
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats()
|
||||
h := &Handler{
|
||||
@@ -248,36 +274,33 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
var models []map[string]interface{}
|
||||
if len(cached) > 0 {
|
||||
for _, m := range cached {
|
||||
models = append(models, map[string]interface{}{
|
||||
"id": m.ModelId, "object": "model", "owned_by": "anthropic",
|
||||
})
|
||||
supportsImage := modelSupportsImage(m.InputTypes)
|
||||
models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage))
|
||||
// 自动生成 thinking 变体
|
||||
models = append(models, map[string]interface{}{
|
||||
"id": m.ModelId + thinkingSuffix, "object": "model", "owned_by": "anthropic",
|
||||
})
|
||||
models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage))
|
||||
}
|
||||
} else {
|
||||
// fallback 静态列表
|
||||
models = []map[string]interface{}{
|
||||
{"id": "claude-sonnet-4.6", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4.6" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.6", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.6" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
buildModelInfo("claude-sonnet-4.6", "anthropic", true),
|
||||
buildModelInfo("claude-sonnet-4.6"+thinkingSuffix, "anthropic", true),
|
||||
buildModelInfo("claude-opus-4.6", "anthropic", true),
|
||||
buildModelInfo("claude-opus-4.6"+thinkingSuffix, "anthropic", true),
|
||||
buildModelInfo("claude-sonnet-4.5", "anthropic", true),
|
||||
buildModelInfo("claude-sonnet-4.5"+thinkingSuffix, "anthropic", true),
|
||||
buildModelInfo("claude-sonnet-4", "anthropic", true),
|
||||
buildModelInfo("claude-sonnet-4"+thinkingSuffix, "anthropic", true),
|
||||
buildModelInfo("claude-haiku-4.5", "anthropic", true),
|
||||
buildModelInfo("claude-haiku-4.5"+thinkingSuffix, "anthropic", true),
|
||||
buildModelInfo("claude-opus-4.5", "anthropic", true),
|
||||
buildModelInfo("claude-opus-4.5"+thinkingSuffix, "anthropic", true),
|
||||
}
|
||||
}
|
||||
// 添加别名模型
|
||||
models = append(models,
|
||||
map[string]interface{}{"id": "auto", "object": "model", "owned_by": "kiro-proxy"},
|
||||
map[string]interface{}{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
|
||||
map[string]interface{}{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
|
||||
buildModelInfo("auto", "kiro-proxy", true),
|
||||
buildModelInfo("gpt-4o", "kiro-proxy", true),
|
||||
buildModelInfo("gpt-4", "kiro-proxy", true),
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
@@ -287,6 +310,49 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// modelSupportsImage reports whether any entry of inputTypes mentions
// "image" or "vision", compared case-insensitively as a substring.
// A nil or empty list yields false.
func modelSupportsImage(inputTypes []string) bool {
	for i := range inputTypes {
		kind := strings.ToLower(inputTypes[i])
		switch {
		case strings.Contains(kind, "image"), strings.Contains(kind, "vision"):
			return true
		}
	}
	return false
}
|
||||
|
||||
// buildModelInfo assembles one entry of the model list for the given model id
// and owner. Image support is advertised redundantly under several keys
// ("supports_image", "input_modalities", "modalities", "capabilities" and
// "info.meta.capabilities"). Output modality is always text only.
func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface{} {
	inputKinds := []string{"text"}
	if supportsImage {
		inputKinds = append(inputKinds, "image")
	}

	topLevelCaps := map[string]bool{
		"vision":       supportsImage,
		"image":        supportsImage,
		"image_vision": supportsImage,
	}
	// The nested capability set carries no plain "image" flag.
	metaCaps := map[string]bool{
		"vision":       supportsImage,
		"image_vision": supportsImage,
	}

	entry := make(map[string]interface{}, 8)
	entry["id"] = id
	entry["object"] = "model"
	entry["owned_by"] = ownedBy
	entry["supports_image"] = supportsImage
	entry["input_modalities"] = inputKinds
	entry["modalities"] = map[string][]string{
		"input":  inputKinds,
		"output": []string{"text"},
	}
	entry["capabilities"] = topLevelCaps
	entry["info"] = map[string]interface{}{
		"meta": map[string]interface{}{
			"capabilities": metaCaps,
		},
	}
	return entry
}
|
||||
|
||||
// refreshModelsCache 从 Kiro API 拉取模型列表并缓存
|
||||
func (h *Handler) refreshModelsCache() {
|
||||
account := h.pool.GetNext()
|
||||
@@ -327,50 +393,13 @@ func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
} `json:"messages"`
|
||||
System interface{} `json:"system"`
|
||||
}
|
||||
var req ClaudeRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON")
|
||||
return
|
||||
}
|
||||
|
||||
// 简单估算 token 数量(每 4 个字符约 1 个 token)
|
||||
var totalChars int
|
||||
for _, msg := range req.Messages {
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
totalChars += len(content)
|
||||
case []interface{}:
|
||||
for _, part := range content {
|
||||
if p, ok := part.(map[string]interface{}); ok {
|
||||
if text, ok := p["text"].(string); ok {
|
||||
totalChars += len(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 系统提示
|
||||
switch system := req.System.(type) {
|
||||
case string:
|
||||
totalChars += len(system)
|
||||
case []interface{}:
|
||||
for _, part := range system {
|
||||
if p, ok := part.(map[string]interface{}); ok {
|
||||
if text, ok := p["text"].(string); ok {
|
||||
totalChars += len(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
estimatedTokens := (totalChars + 3) / 4 // 向上取整
|
||||
estimatedTokens := estimateClaudeRequestInputTokens(&req)
|
||||
if estimatedTokens < 1 {
|
||||
estimatedTokens = 1
|
||||
}
|
||||
@@ -381,6 +410,10 @@ func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleClaudeMessages handles Claude API message requests. It is a thin
// entry point that forwards the response writer and request unchanged to
// handleClaudeMessagesInternal.
|
||||
func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleClaudeMessagesInternal(w, r)
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method Not Allowed", 405)
|
||||
return
|
||||
@@ -416,20 +449,21 @@ func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
|
||||
thinkingCfg := config.GetThinkingConfig()
|
||||
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
|
||||
req.Model = actualModel
|
||||
estimatedInputTokens := estimateClaudeRequestInputTokens(&req)
|
||||
|
||||
// 转换请求
|
||||
kiroPayload := ClaudeToKiro(&req, thinking)
|
||||
|
||||
// 流式或非流式
|
||||
if req.Stream {
|
||||
h.handleClaudeStream(w, account, kiroPayload, req.Model)
|
||||
h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
|
||||
} else {
|
||||
h.handleClaudeNonStream(w, account, kiroPayload, req.Model)
|
||||
h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// handleClaudeStream Claude 流式响应
|
||||
func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
|
||||
func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
|
||||
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
@@ -444,91 +478,169 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
|
||||
|
||||
msgID := "msg_" + uuid.New().String()
|
||||
var contentStarted bool
|
||||
var toolUseIndex int
|
||||
var inputTokens, outputTokens int
|
||||
var credits float64
|
||||
var toolUses []KiroToolUse
|
||||
var nextContentIndex int
|
||||
var rawContentBuilder strings.Builder
|
||||
var rawThinkingBuilder strings.Builder
|
||||
activeBlockIndex := -1
|
||||
activeBlockType := ""
|
||||
startInputTokens := estimatedInputTokens
|
||||
|
||||
closeActiveBlock := func() {
|
||||
if activeBlockIndex < 0 {
|
||||
return
|
||||
}
|
||||
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": activeBlockIndex,
|
||||
})
|
||||
activeBlockIndex = -1
|
||||
activeBlockType = ""
|
||||
}
|
||||
|
||||
startContentBlock := func(blockType string) {
|
||||
if activeBlockType == blockType {
|
||||
return
|
||||
}
|
||||
closeActiveBlock()
|
||||
|
||||
idx := nextContentIndex
|
||||
nextContentIndex++
|
||||
|
||||
if blockType == "thinking" {
|
||||
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]string{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
} else {
|
||||
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]string{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
activeBlockIndex = idx
|
||||
activeBlockType = blockType
|
||||
}
|
||||
|
||||
// Thinking 标签解析状态
|
||||
var textBuffer string
|
||||
var inThinkingBlock bool
|
||||
var dropTagThinking bool
|
||||
var thinkingSource thinkingStreamSource
|
||||
|
||||
// 发送文本的辅助函数
|
||||
// thinkingState: 0=普通内容, 1=thinking开始, 2=thinking中间, 3=thinking结束
|
||||
sendText := func(text string, thinkingState int) {
|
||||
// 确保 content_block 已开始
|
||||
if !contentStarted {
|
||||
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]string{"type": "text", "text": ""},
|
||||
})
|
||||
contentStarted = true
|
||||
}
|
||||
|
||||
if thinkingState == 0 {
|
||||
// 普通内容
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
startContentBlock("text")
|
||||
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"index": activeBlockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
} else {
|
||||
// thinking 内容
|
||||
return
|
||||
}
|
||||
|
||||
if !thinking {
|
||||
return
|
||||
}
|
||||
|
||||
switch thinkingFormat {
|
||||
case "think":
|
||||
var outputText string
|
||||
switch thinkingFormat {
|
||||
case "think":
|
||||
switch thinkingState {
|
||||
case 1:
|
||||
outputText = "<think>" + text
|
||||
case 2:
|
||||
outputText = text
|
||||
case 3:
|
||||
outputText = text + "</think>"
|
||||
}
|
||||
case "reasoning_content":
|
||||
// Claude 格式不支持 reasoning_content,直接输出内容
|
||||
switch thinkingState {
|
||||
case 1:
|
||||
outputText = "<think>" + text
|
||||
case 2:
|
||||
outputText = text
|
||||
default: // "thinking"
|
||||
switch thinkingState {
|
||||
case 1:
|
||||
outputText = "<thinking>" + text
|
||||
case 2:
|
||||
outputText = text
|
||||
case 3:
|
||||
outputText = text + "</thinking>"
|
||||
}
|
||||
case 3:
|
||||
outputText = text + "</think>"
|
||||
}
|
||||
if outputText == "" {
|
||||
return
|
||||
}
|
||||
startContentBlock("text")
|
||||
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"index": activeBlockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": outputText},
|
||||
})
|
||||
case "reasoning_content":
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
startContentBlock("text")
|
||||
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": activeBlockIndex,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
default:
|
||||
if thinkingState == 3 && text == "" {
|
||||
if activeBlockType == "thinking" {
|
||||
closeActiveBlock()
|
||||
}
|
||||
return
|
||||
}
|
||||
if text != "" {
|
||||
startContentBlock("thinking")
|
||||
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": activeBlockIndex,
|
||||
"delta": map[string]string{"type": "thinking_delta", "thinking": text},
|
||||
})
|
||||
}
|
||||
if thinkingState == 3 && activeBlockType == "thinking" {
|
||||
closeActiveBlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理文本,解析 <thinking> 标签
|
||||
var thinkingStarted bool
|
||||
var eventThinkingOpen bool
|
||||
|
||||
processClaudeText := func(text string, isThinking bool, forceFlush bool) {
|
||||
if isThinking && !thinking {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是 reasoningContentEvent,直接输出
|
||||
if isThinking {
|
||||
if !allowReasoningSource(&thinkingSource) {
|
||||
return
|
||||
}
|
||||
if !thinkingStarted {
|
||||
sendText(text, 1)
|
||||
thinkingStarted = true
|
||||
eventThinkingOpen = true
|
||||
} else {
|
||||
sendText(text, 2)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if eventThinkingOpen {
|
||||
sendText("", 3)
|
||||
eventThinkingOpen = false
|
||||
thinkingStarted = false
|
||||
}
|
||||
|
||||
textBuffer += text
|
||||
|
||||
for {
|
||||
@@ -540,6 +652,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
}
|
||||
textBuffer = textBuffer[thinkingStart+10:]
|
||||
inThinkingBlock = true
|
||||
dropTagThinking = !allowTagSource(&thinkingSource)
|
||||
thinkingStarted = false
|
||||
} else if forceFlush || len([]rune(textBuffer)) > 50 {
|
||||
// 使用 rune 切片来正确处理 Unicode 字符
|
||||
@@ -560,25 +673,33 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
thinkingEnd := strings.Index(textBuffer, "</thinking>")
|
||||
if thinkingEnd != -1 {
|
||||
content := textBuffer[:thinkingEnd]
|
||||
if !thinkingStarted {
|
||||
sendText(content, 1)
|
||||
sendText("", 3)
|
||||
} else {
|
||||
sendText(content, 3)
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
sendText(content, 1)
|
||||
sendText("", 3)
|
||||
} else {
|
||||
sendText(content, 3)
|
||||
}
|
||||
}
|
||||
textBuffer = textBuffer[thinkingEnd+11:]
|
||||
inThinkingBlock = false
|
||||
dropTagThinking = false
|
||||
thinkingStarted = false
|
||||
} else if forceFlush {
|
||||
if textBuffer != "" {
|
||||
if !thinkingStarted {
|
||||
sendText(textBuffer, 1)
|
||||
sendText("", 3)
|
||||
} else {
|
||||
sendText(textBuffer, 3)
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
sendText(textBuffer, 1)
|
||||
sendText("", 3)
|
||||
} else {
|
||||
sendText(textBuffer, 3)
|
||||
}
|
||||
}
|
||||
textBuffer = ""
|
||||
}
|
||||
inThinkingBlock = false
|
||||
dropTagThinking = false
|
||||
thinkingStarted = false
|
||||
break
|
||||
} else {
|
||||
// 流式输出 thinking 块内的内容
|
||||
@@ -586,11 +707,13 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
if len(runes) > 20 {
|
||||
safeLen := len(runes) - 15
|
||||
if safeLen > 0 {
|
||||
if !thinkingStarted {
|
||||
sendText(string(runes[:safeLen]), 1)
|
||||
thinkingStarted = true
|
||||
} else {
|
||||
sendText(string(runes[:safeLen]), 2)
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
sendText(string(runes[:safeLen]), 1)
|
||||
thinkingStarted = true
|
||||
} else {
|
||||
sendText(string(runes[:safeLen]), 2)
|
||||
}
|
||||
}
|
||||
textBuffer = string(runes[safeLen:])
|
||||
}
|
||||
@@ -605,11 +728,17 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
h.sendSSE(w, flusher, "message_start", map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []interface{}{},
|
||||
"model": model,
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []interface{}{},
|
||||
"model": model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]int{
|
||||
"input_tokens": startInputTokens,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -618,27 +747,26 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if isThinking {
|
||||
rawThinkingBuilder.WriteString(text)
|
||||
} else {
|
||||
rawContentBuilder.WriteString(text)
|
||||
}
|
||||
processClaudeText(text, isThinking, false)
|
||||
},
|
||||
OnToolUse: func(tu KiroToolUse) {
|
||||
// 先刷新缓冲区
|
||||
processClaudeText("", false, true)
|
||||
rawContentBuilder.WriteString(tu.Name)
|
||||
if b, err := json.Marshal(tu.Input); err == nil {
|
||||
rawContentBuilder.Write(b)
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, tu)
|
||||
closeActiveBlock()
|
||||
|
||||
// 关闭文本块
|
||||
if contentStarted && toolUseIndex == 0 {
|
||||
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
})
|
||||
}
|
||||
|
||||
idx := toolUseIndex
|
||||
if contentStarted {
|
||||
idx = toolUseIndex + 1
|
||||
}
|
||||
toolUseIndex++
|
||||
idx := nextContentIndex
|
||||
nextContentIndex++
|
||||
|
||||
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
@@ -691,19 +819,27 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
|
||||
// 刷新剩余缓冲区
|
||||
processClaudeText("", false, true)
|
||||
if eventThinkingOpen {
|
||||
sendText("", 3)
|
||||
eventThinkingOpen = false
|
||||
}
|
||||
closeActiveBlock()
|
||||
|
||||
inputTokens = estimatedInputTokens
|
||||
outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
|
||||
thinkingOutput := rawThinkingBuilder.String()
|
||||
if thinking && thinkingOutput == "" && extractedReasoning != "" {
|
||||
thinkingOutput = extractedReasoning
|
||||
}
|
||||
if !thinking {
|
||||
thinkingOutput = ""
|
||||
}
|
||||
outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)
|
||||
|
||||
h.recordSuccess(inputTokens, outputTokens, credits)
|
||||
h.pool.RecordSuccess(account.ID)
|
||||
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
|
||||
|
||||
// 关闭最后的内容块
|
||||
if contentStarted && toolUseIndex == 0 {
|
||||
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
})
|
||||
}
|
||||
|
||||
// 发送 message_delta
|
||||
stopReason := "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
@@ -787,7 +923,7 @@ func (h *Handler) recordFailure() {
|
||||
}
|
||||
|
||||
// handleClaudeNonStream Claude 非流式响应
|
||||
func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
|
||||
func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
|
||||
var content string
|
||||
var thinkingContent string
|
||||
var toolUses []KiroToolUse
|
||||
@@ -825,25 +961,36 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A
|
||||
return
|
||||
}
|
||||
|
||||
// 合并 thinking 内容(如果有 reasoningContentEvent 的内容)
|
||||
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
|
||||
finalContent, extractedReasoning := extractThinkingFromContent(content)
|
||||
if thinking && thinkingContent == "" && extractedReasoning != "" {
|
||||
thinkingContent = extractedReasoning
|
||||
}
|
||||
if !thinking {
|
||||
thinkingContent = ""
|
||||
}
|
||||
|
||||
inputTokens = estimatedInputTokens
|
||||
outputTokens = estimateClaudeOutputTokens(finalContent, thinkingContent, toolUses)
|
||||
|
||||
h.recordSuccess(inputTokens, outputTokens, credits)
|
||||
h.pool.RecordSuccess(account.ID)
|
||||
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
|
||||
|
||||
// 合并 thinking 内容(如果有 reasoningContentEvent 的内容)
|
||||
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
|
||||
finalContent := content
|
||||
if thinkingContent != "" {
|
||||
if thinking && thinkingContent != "" {
|
||||
switch thinkingFormat {
|
||||
case "think":
|
||||
finalContent = "<think>" + thinkingContent + "</think>" + content
|
||||
finalContent = "<think>" + thinkingContent + "</think>" + finalContent
|
||||
thinkingContent = ""
|
||||
case "reasoning_content":
|
||||
finalContent = thinkingContent + content // Claude 格式不支持 reasoning_content,直接拼接
|
||||
default: // "thinking"
|
||||
finalContent = "<thinking>" + thinkingContent + "</thinking>" + content
|
||||
finalContent = thinkingContent + finalContent // Claude 格式不支持 reasoning_content,直接拼接
|
||||
thinkingContent = ""
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
resp := KiroToClaudeResponse(finalContent, toolUses, inputTokens, outputTokens, model)
|
||||
resp := KiroToClaudeResponse(finalContent, thinkingContent, toolUses, inputTokens, outputTokens, model)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
@@ -894,18 +1041,19 @@ func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
|
||||
thinkingCfg := config.GetThinkingConfig()
|
||||
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
|
||||
req.Model = actualModel
|
||||
estimatedInputTokens := estimateOpenAIRequestInputTokens(&req)
|
||||
|
||||
kiroPayload := OpenAIToKiro(&req, thinking)
|
||||
|
||||
if req.Stream {
|
||||
h.handleOpenAIStream(w, account, kiroPayload, req.Model)
|
||||
h.handleOpenAIStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
|
||||
} else {
|
||||
h.handleOpenAINonStream(w, account, kiroPayload, req.Model)
|
||||
h.handleOpenAINonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOpenAIStream OpenAI 流式响应
|
||||
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
|
||||
func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
|
||||
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
@@ -924,10 +1072,14 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
var toolCallIndex int
|
||||
var inputTokens, outputTokens int
|
||||
var credits float64
|
||||
var rawContentBuilder strings.Builder
|
||||
var rawReasoningBuilder strings.Builder
|
||||
|
||||
// Thinking 标签解析状态
|
||||
var textBuffer string
|
||||
var inThinkingBlock bool
|
||||
var dropTagThinking bool
|
||||
var thinkingSource thinkingStreamSource
|
||||
|
||||
// 发送 chunk 的辅助函数
|
||||
// thinkingState: 0=普通内容, 1=thinking开始, 2=thinking中间, 3=thinking结束
|
||||
@@ -939,6 +1091,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
var chunk map[string]interface{}
|
||||
|
||||
if thinkingState > 0 {
|
||||
if !thinking {
|
||||
return
|
||||
}
|
||||
// thinking 内容
|
||||
switch thinkingFormat {
|
||||
case "thinking":
|
||||
@@ -1031,19 +1186,34 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
// 处理文本,解析 <thinking> 标签
|
||||
// thinkingStarted 用于跟踪是否已发送开始标签
|
||||
var thinkingStarted bool
|
||||
var eventThinkingOpen bool
|
||||
|
||||
processText := func(text string, isThinking bool, forceFlush bool) {
|
||||
if isThinking && !thinking {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是 reasoningContentEvent,直接输出
|
||||
if isThinking {
|
||||
if !allowReasoningSource(&thinkingSource) {
|
||||
return
|
||||
}
|
||||
if !thinkingStarted {
|
||||
sendChunk(text, 1) // 开始
|
||||
thinkingStarted = true
|
||||
eventThinkingOpen = true
|
||||
} else {
|
||||
sendChunk(text, 2) // 中间
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if eventThinkingOpen {
|
||||
sendChunk("", 3)
|
||||
eventThinkingOpen = false
|
||||
thinkingStarted = false
|
||||
}
|
||||
|
||||
textBuffer += text
|
||||
|
||||
for {
|
||||
@@ -1057,6 +1227,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
}
|
||||
textBuffer = textBuffer[thinkingStart+10:] // 移除 <thinking>
|
||||
inThinkingBlock = true
|
||||
dropTagThinking = !allowTagSource(&thinkingSource)
|
||||
thinkingStarted = false // 重置,准备发送新的开始标签
|
||||
} else if forceFlush || len([]rune(textBuffer)) > 50 {
|
||||
// 没有找到标签,安全输出(保留可能的部分标签)
|
||||
@@ -1079,28 +1250,36 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
if thinkingEnd != -1 {
|
||||
// 输出 thinking 内容
|
||||
content := textBuffer[:thinkingEnd]
|
||||
if !thinkingStarted {
|
||||
// 一次性输出完整内容(开始+内容+结束)
|
||||
sendChunk(content, 1) // 开始
|
||||
sendChunk("", 3) // 结束(空内容,只发结束标签)
|
||||
} else {
|
||||
// 已经开始了,发送剩余内容和结束
|
||||
sendChunk(content, 3) // 结束
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
// 一次性输出完整内容(开始+内容+结束)
|
||||
sendChunk(content, 1) // 开始
|
||||
sendChunk("", 3) // 结束(空内容,只发结束标签)
|
||||
} else {
|
||||
// 已经开始了,发送剩余内容和结束
|
||||
sendChunk(content, 3) // 结束
|
||||
}
|
||||
}
|
||||
textBuffer = textBuffer[thinkingEnd+11:] // 移除 </thinking>
|
||||
inThinkingBlock = false
|
||||
dropTagThinking = false
|
||||
thinkingStarted = false
|
||||
} else if forceFlush {
|
||||
// 强制刷新:输出剩余内容
|
||||
if textBuffer != "" {
|
||||
if !thinkingStarted {
|
||||
sendChunk(textBuffer, 1) // 开始
|
||||
sendChunk("", 3) // 结束
|
||||
} else {
|
||||
sendChunk(textBuffer, 3) // 结束
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
sendChunk(textBuffer, 1) // 开始
|
||||
sendChunk("", 3) // 结束
|
||||
} else {
|
||||
sendChunk(textBuffer, 3) // 结束
|
||||
}
|
||||
}
|
||||
textBuffer = ""
|
||||
}
|
||||
inThinkingBlock = false
|
||||
dropTagThinking = false
|
||||
thinkingStarted = false
|
||||
break
|
||||
} else {
|
||||
// 流式输出 thinking 块内的内容
|
||||
@@ -1108,11 +1287,13 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
if len(runes) > 20 {
|
||||
safeLen := len(runes) - 15 // 保留可能的 </thinking> 部分
|
||||
if safeLen > 0 {
|
||||
if !thinkingStarted {
|
||||
sendChunk(string(runes[:safeLen]), 1) // 开始
|
||||
thinkingStarted = true
|
||||
} else {
|
||||
sendChunk(string(runes[:safeLen]), 2) // 中间
|
||||
if !dropTagThinking {
|
||||
if !thinkingStarted {
|
||||
sendChunk(string(runes[:safeLen]), 1) // 开始
|
||||
thinkingStarted = true
|
||||
} else {
|
||||
sendChunk(string(runes[:safeLen]), 2) // 中间
|
||||
}
|
||||
}
|
||||
textBuffer = string(runes[safeLen:])
|
||||
}
|
||||
@@ -1128,6 +1309,11 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
if isThinking {
|
||||
rawReasoningBuilder.WriteString(text)
|
||||
} else {
|
||||
rawContentBuilder.WriteString(text)
|
||||
}
|
||||
processText(text, isThinking, false)
|
||||
},
|
||||
OnToolUse: func(tu KiroToolUse) {
|
||||
@@ -1135,6 +1321,8 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
processText("", false, true)
|
||||
|
||||
args, _ := json.Marshal(tu.Input)
|
||||
rawContentBuilder.WriteString(tu.Name)
|
||||
rawContentBuilder.Write(args)
|
||||
tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
|
||||
tc.Function.Name = tu.Name
|
||||
tc.Function.Arguments = string(args)
|
||||
@@ -1187,6 +1375,25 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
|
||||
// 刷新剩余缓冲区
|
||||
processText("", false, true)
|
||||
if eventThinkingOpen {
|
||||
sendChunk("", 3)
|
||||
eventThinkingOpen = false
|
||||
}
|
||||
|
||||
inputTokens = estimatedInputTokens
|
||||
outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
|
||||
reasoningOutput := rawReasoningBuilder.String()
|
||||
if thinking && reasoningOutput == "" && extractedReasoning != "" {
|
||||
reasoningOutput = extractedReasoning
|
||||
}
|
||||
if !thinking {
|
||||
reasoningOutput = ""
|
||||
}
|
||||
outputTokens = estimateApproxTokens(outputContent) + estimateApproxTokens(reasoningOutput)
|
||||
for _, tc := range toolCalls {
|
||||
outputTokens += estimateApproxTokens(tc.Function.Name)
|
||||
outputTokens += estimateApproxTokens(tc.Function.Arguments)
|
||||
}
|
||||
|
||||
h.recordSuccess(inputTokens, outputTokens, credits)
|
||||
h.pool.RecordSuccess(account.ID)
|
||||
@@ -1221,7 +1428,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
}
|
||||
|
||||
// handleOpenAINonStream OpenAI 非流式响应
|
||||
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
|
||||
func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
|
||||
var content string
|
||||
var reasoningContent string
|
||||
var toolUses []KiroToolUse
|
||||
@@ -1250,16 +1457,21 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 content 中的 <thinking> 标签
|
||||
finalContent, extractedReasoning := extractThinkingFromContent(content)
|
||||
if thinking && reasoningContent == "" && extractedReasoning != "" {
|
||||
reasoningContent = extractedReasoning
|
||||
} else if !thinking {
|
||||
reasoningContent = ""
|
||||
}
|
||||
|
||||
inputTokens = estimatedInputTokens
|
||||
outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)
|
||||
|
||||
h.recordSuccess(inputTokens, outputTokens, credits)
|
||||
h.pool.RecordSuccess(account.ID)
|
||||
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
|
||||
|
||||
// 解析 content 中的 <thinking> 标签
|
||||
finalContent, extractedReasoning := extractThinkingFromContent(content)
|
||||
if extractedReasoning != "" {
|
||||
reasoningContent = extractedReasoning + reasoningContent
|
||||
}
|
||||
|
||||
thinkingFormat := config.GetThinkingConfig().OpenAIFormat
|
||||
resp := KiroToOpenAIResponseWithReasoning(finalContent, reasoningContent, toolUses, inputTokens, outputTokens, model, thinkingFormat)
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
50
proxy/handler_test.go
Normal file
50
proxy/handler_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestThinkingSourceReasoningFirst(t *testing.T) {
|
||||
var source thinkingStreamSource
|
||||
|
||||
if !allowReasoningSource(&source) {
|
||||
t.Fatalf("expected reasoning source to be accepted first")
|
||||
}
|
||||
if source != thinkingSourceReasoningEvent {
|
||||
t.Fatalf("expected source to be reasoning, got %v", source)
|
||||
}
|
||||
if allowTagSource(&source) {
|
||||
t.Fatalf("expected tag source to be rejected after reasoning source selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingSourceTagFirst(t *testing.T) {
|
||||
var source thinkingStreamSource
|
||||
|
||||
if !allowTagSource(&source) {
|
||||
t.Fatalf("expected tag source to be accepted first")
|
||||
}
|
||||
if source != thinkingSourceTagBlock {
|
||||
t.Fatalf("expected source to be tag, got %v", source)
|
||||
}
|
||||
if allowReasoningSource(&source) {
|
||||
t.Fatalf("expected reasoning source to be rejected after tag source selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingSourceSameSourceRemainsAllowed(t *testing.T) {
|
||||
var source thinkingStreamSource
|
||||
|
||||
if !allowTagSource(&source) {
|
||||
t.Fatalf("expected initial tag source selection to succeed")
|
||||
}
|
||||
if !allowTagSource(&source) {
|
||||
t.Fatalf("expected repeated tag source selection to stay allowed")
|
||||
}
|
||||
|
||||
source = thinkingSourceUnknown
|
||||
if !allowReasoningSource(&source) {
|
||||
t.Fatalf("expected initial reasoning source selection to succeed")
|
||||
}
|
||||
if !allowReasoningSource(&source) {
|
||||
t.Fatalf("expected repeated reasoning source selection to stay allowed")
|
||||
}
|
||||
}
|
||||
193
proxy/kiro.go
193
proxy/kiro.go
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"kiro-api-proxy/config"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -159,14 +160,10 @@ func getSortedEndpoints(preferred string) []kiroEndpoint {
|
||||
|
||||
// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback
|
||||
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
if _, err := json.Marshal(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 预估输入 token(约 3 字符 = 1 token)
|
||||
estimatedInputTokens := max(1, len(body)/3)
|
||||
|
||||
// User-Agent
|
||||
machineId := account.MachineId
|
||||
var userAgent, amzUserAgent string
|
||||
@@ -230,7 +227,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
|
||||
continue
|
||||
}
|
||||
|
||||
err = parseEventStream(resp.Body, callback, estimatedInputTokens)
|
||||
err = parseEventStream(resp.Body, callback)
|
||||
resp.Body.Close()
|
||||
return err
|
||||
}
|
||||
@@ -244,12 +241,13 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
|
||||
// ==================== Event Stream 解析 ====================
|
||||
|
||||
// parseEventStream 解析 AWS Event Stream 二进制格式
|
||||
func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInputTokens int) error {
|
||||
func parseEventStream(body io.Reader, callback *KiroStreamCallback) error {
|
||||
// 不使用 bufio,直接读取避免缓冲延迟
|
||||
var inputTokens, outputTokens int
|
||||
var totalOutputChars int
|
||||
var totalCredits float64
|
||||
var currentToolUse *toolUseState
|
||||
var lastAssistantContent string
|
||||
var lastReasoningContent string
|
||||
|
||||
for {
|
||||
// Prelude: 12 bytes (total_len + headers_len + crc)
|
||||
@@ -292,30 +290,26 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
|
||||
continue
|
||||
}
|
||||
|
||||
inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens)
|
||||
|
||||
// 处理事件
|
||||
switch eventType {
|
||||
case "assistantResponseEvent":
|
||||
if content, ok := event["content"].(string); ok && content != "" {
|
||||
callback.OnText(content, false)
|
||||
totalOutputChars += len(content)
|
||||
normalized := normalizeChunk(content, &lastAssistantContent)
|
||||
if normalized != "" {
|
||||
callback.OnText(normalized, false)
|
||||
}
|
||||
}
|
||||
case "reasoningContentEvent":
|
||||
if text, ok := event["text"].(string); ok && text != "" {
|
||||
callback.OnText(text, true)
|
||||
totalOutputChars += len(text)
|
||||
normalized := normalizeChunk(text, &lastReasoningContent)
|
||||
if normalized != "" {
|
||||
callback.OnText(normalized, true)
|
||||
}
|
||||
}
|
||||
case "toolUseEvent":
|
||||
currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
|
||||
case "messageMetadataEvent", "metadataEvent":
|
||||
if tokenUsage, ok := event["tokenUsage"].(map[string]interface{}); ok {
|
||||
if v, ok := tokenUsage["outputTokens"].(float64); ok {
|
||||
outputTokens = int(v)
|
||||
}
|
||||
uncached, _ := tokenUsage["uncachedInputTokens"].(float64)
|
||||
cacheRead, _ := tokenUsage["cacheReadInputTokens"].(float64)
|
||||
cacheWrite, _ := tokenUsage["cacheWriteInputTokens"].(float64)
|
||||
inputTokens = int(uncached + cacheRead + cacheWrite)
|
||||
}
|
||||
case "meteringEvent":
|
||||
if usage, ok := event["usage"].(float64); ok {
|
||||
totalCredits += usage
|
||||
@@ -323,15 +317,6 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
|
||||
}
|
||||
}
|
||||
|
||||
// 估算 token(约 3 字符 = 1 token)
|
||||
if outputTokens == 0 && totalOutputChars > 0 {
|
||||
outputTokens = max(1, totalOutputChars/3)
|
||||
}
|
||||
// 如果 Kiro 没返回 inputTokens,使用预估值
|
||||
if inputTokens == 0 {
|
||||
inputTokens = estimatedInputTokens
|
||||
}
|
||||
|
||||
if callback.OnCredits != nil && totalCredits > 0 {
|
||||
callback.OnCredits(totalCredits)
|
||||
}
|
||||
@@ -340,6 +325,152 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, currentOutputTokens int) (int, int) {
|
||||
candidates := []map[string]interface{}{event}
|
||||
collectUsageMaps(event, &candidates)
|
||||
|
||||
inputTokens := currentInputTokens
|
||||
outputTokens := currentOutputTokens
|
||||
|
||||
for _, usage := range candidates {
|
||||
if usage == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if v, ok := readTokenNumber(usage,
|
||||
"outputTokens", "completionTokens", "totalOutputTokens",
|
||||
"output_tokens", "completion_tokens", "total_output_tokens",
|
||||
); ok {
|
||||
outputTokens = v
|
||||
}
|
||||
|
||||
if v, ok := readTokenNumber(usage,
|
||||
"inputTokens", "promptTokens", "totalInputTokens",
|
||||
"input_tokens", "prompt_tokens", "total_input_tokens",
|
||||
); ok {
|
||||
inputTokens = v
|
||||
continue
|
||||
}
|
||||
|
||||
uncached, _ := readTokenNumber(usage, "uncachedInputTokens", "uncached_input_tokens")
|
||||
cacheRead, _ := readTokenNumber(usage, "cacheReadInputTokens", "cache_read_input_tokens")
|
||||
cacheWrite, _ := readTokenNumber(usage, "cacheWriteInputTokens", "cache_write_input_tokens", "cacheCreationInputTokens", "cache_creation_input_tokens")
|
||||
if uncached+cacheRead+cacheWrite > 0 {
|
||||
inputTokens = uncached + cacheRead + cacheWrite
|
||||
continue
|
||||
}
|
||||
|
||||
total, ok := readTokenNumber(usage, "totalTokens", "total_tokens")
|
||||
if ok && total > 0 {
|
||||
candidateOutput := outputTokens
|
||||
if v, vok := readTokenNumber(usage,
|
||||
"outputTokens", "completionTokens", "totalOutputTokens",
|
||||
"output_tokens", "completion_tokens", "total_output_tokens",
|
||||
); vok {
|
||||
candidateOutput = v
|
||||
}
|
||||
if total-candidateOutput > 0 {
|
||||
inputTokens = total - candidateOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputTokens, outputTokens
|
||||
}
|
||||
|
||||
// collectUsageMaps walks v recursively through maps and slices and appends to
// out every map stored under a key that equals "usage", "tokenusage" or
// "token_usage" when lower-cased. Values under such keys that are not maps are
// not appended. Every child value is descended into, including the appended
// maps themselves. Values that are neither maps nor slices are ignored.
func collectUsageMaps(v interface{}, out *[]map[string]interface{}) {
	if list, isList := v.([]interface{}); isList {
		for i := range list {
			collectUsageMaps(list[i], out)
		}
		return
	}

	obj, isObj := v.(map[string]interface{})
	if !isObj {
		return
	}

	for key, child := range obj {
		switch strings.ToLower(key) {
		case "usage", "tokenusage", "token_usage":
			if nested, isMap := child.(map[string]interface{}); isMap {
				*out = append(*out, nested)
			}
		}
		collectUsageMaps(child, out)
	}
}
|
||||
|
||||
// normalizeChunk returns the part of chunk that has not been seen yet,
// judged against the last accepted chunk stored in *previous.
//
//   - An empty chunk yields "" and leaves *previous alone.
//   - With no previous chunk, chunk is returned whole and remembered.
//   - A chunk that *previous already starts with (including an exact repeat)
//     yields "" and *previous is kept.
//   - A chunk that starts with *previous yields only the extra tail and is
//     remembered.
//   - Otherwise the longest prefix of chunk that is also a suffix of
//     *previous is dropped, the rest is returned, and chunk is remembered.
func normalizeChunk(chunk string, previous *string) string {
	if chunk == "" {
		return ""
	}

	last := *previous
	switch {
	case last == "":
		*previous = chunk
		return chunk
	case strings.HasPrefix(last, chunk):
		// Exact repeat or a rewind to an earlier prefix: nothing new.
		return ""
	case strings.HasPrefix(chunk, last):
		*previous = chunk
		return chunk[len(last):]
	}

	// Search from the longest possible overlap downwards.
	limit := len(chunk)
	if len(last) < limit {
		limit = len(last)
	}
	overlap := 0
	for n := limit; n > 0 && overlap == 0; n-- {
		if strings.HasSuffix(last, chunk[:n]) {
			overlap = n
		}
	}

	*previous = chunk
	return chunk[overlap:]
}
|
||||
|
||||
// readTokenNumber returns the first value in m, looked up under keys in the
// given order, that can be read as an integer.
//
// Accepted value types are float64, int, int64, json.Number and numeric
// strings. Fractional values are truncated toward zero. A key is skipped, and
// the next one tried, when it is absent or when its value cannot be converted:
// an unsupported type, unparsable text, NaN, an infinity, or a magnitude above
// 2^53. The second result is false when no key yields a number.
func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) {
	// fromFloat truncates f to an int. The range test is written so that NaN
	// fails it too (every comparison with NaN is false); converting NaN or an
	// out-of-range float to int gives an implementation-defined result.
	fromFloat := func(f float64) (int, bool) {
		const limit = 1 << 53 // integers up to this magnitude are exact in float64
		if !(f >= -limit && f <= limit) {
			return 0, false
		}
		return int(f), true
	}

	for _, k := range keys {
		v, ok := m[k]
		if !ok {
			continue
		}
		switch n := v.(type) {
		case float64:
			if parsed, ok := fromFloat(n); ok {
				return parsed, true
			}
		case int:
			return n, true
		case int64:
			return int(n), true
		case json.Number:
			if parsed, err := n.Int64(); err == nil {
				return int(parsed), true
			}
			// Decimal forms such as "12.0" are not valid for Int64; accept
			// them the same way the string case does.
			if f, err := n.Float64(); err == nil {
				if parsed, ok := fromFloat(f); ok {
					return parsed, true
				}
			}
		case string:
			if parsed, err := strconv.Atoi(n); err == nil {
				return parsed, true
			}
			// ParseFloat also accepts "NaN" and "Inf"; fromFloat rejects them.
			if f, err := strconv.ParseFloat(n, 64); err == nil {
				if parsed, ok := fromFloat(f); ok {
					return parsed, true
				}
			}
		}
	}
	return 0, false
}
|
||||
|
||||
// ==================== Tool Use 处理 ====================
|
||||
|
||||
type toolUseState struct {
|
||||
|
||||
37
proxy/kiro_test.go
Normal file
37
proxy/kiro_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package proxy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeChunkBasicProgression(t *testing.T) {
|
||||
prev := ""
|
||||
|
||||
if got := normalizeChunk("abc", &prev); got != "abc" {
|
||||
t.Fatalf("expected first chunk to pass through, got %q", got)
|
||||
}
|
||||
if got := normalizeChunk("abcde", &prev); got != "de" {
|
||||
t.Fatalf("expected appended delta, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeChunkPrefixRewindDoesNotReplay(t *testing.T) {
|
||||
prev := ""
|
||||
|
||||
_ = normalizeChunk("abcde", &prev)
|
||||
if got := normalizeChunk("abc", &prev); got != "" {
|
||||
t.Fatalf("expected rewind chunk to be ignored, got %q", got)
|
||||
}
|
||||
if prev != "abcde" {
|
||||
t.Fatalf("expected previous snapshot to remain longest version, got %q", prev)
|
||||
}
|
||||
if got := normalizeChunk("abcdef", &prev); got != "f" {
|
||||
t.Fatalf("expected only unseen suffix after rewind, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeChunkOverlapDelta(t *testing.T) {
|
||||
prev := "hello world"
|
||||
|
||||
if got := normalizeChunk("world!!!", &prev); got != "!!!" {
|
||||
t.Fatalf("expected overlap suffix delta, got %q", got)
|
||||
}
|
||||
}
|
||||
196
proxy/token_estimator.go
Normal file
196
proxy/token_estimator.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
)
|
||||
|
||||
// estimateApproxTokens returns a heuristic token count for text.
//
// Empty text counts as 0. Text shorter than five runes counts as one token per
// three runes, rounded up. Longer text is weighted by character class: ASCII
// punctuation and non-ASCII runes at 1.5 per token, ASCII digits at 2 per
// token, and all remaining ASCII (letters, whitespace, control characters) at
// 4.5 per token; the sum is rounded up. Non-empty text never counts as less
// than 1.
func estimateApproxTokens(text string) int {
	var plain, numeric, punct, wide, runeCount int

	// Ranging over the string decodes it rune by rune without copying it.
	for _, r := range text {
		runeCount++
		switch {
		case r >= 0x80:
			wide++
		case '0' <= r && r <= '9':
			numeric++
		case '!' <= r && r <= '/', ':' <= r && r <= '@', '[' <= r && r <= '`', '{' <= r && r <= '~':
			punct++
		default:
			plain++
		}
	}

	if runeCount == 0 {
		return 0
	}
	if runeCount < 5 {
		// Integer form of ceil(runeCount / 3); always at least 1 here.
		return (runeCount + 2) / 3
	}

	weighted := float64(plain)/4.5 +
		float64(numeric)/2.0 +
		float64(punct)/1.5 +
		float64(wide)/1.5

	return max(1, int(math.Ceil(weighted)))
}
|
||||
|
||||
func estimateClaudeRequestInputTokens(req *ClaudeRequest) int {
|
||||
if req == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
total := estimateClaudeValueTokens(req.System)
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
total += estimateClaudeValueTokens(msg.Content)
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
total += estimateApproxTokens(tool.Name)
|
||||
total += estimateApproxTokens(tool.Description)
|
||||
total += estimateJSONTokens(tool.InputSchema)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateClaudeOutputTokens(content, thinkingContent string, toolUses []KiroToolUse) int {
|
||||
total := estimateApproxTokens(content)
|
||||
total += estimateApproxTokens(thinkingContent)
|
||||
|
||||
for _, tu := range toolUses {
|
||||
total += estimateApproxTokens(tu.Name)
|
||||
total += estimateJSONTokens(tu.Input)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateClaudeValueTokens(v interface{}) int {
|
||||
switch value := v.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case string:
|
||||
return estimateApproxTokens(value)
|
||||
case []interface{}:
|
||||
total := 0
|
||||
for _, part := range value {
|
||||
total += estimateClaudeValueTokens(part)
|
||||
}
|
||||
return total
|
||||
case map[string]interface{}:
|
||||
typeName, _ := value["type"].(string)
|
||||
switch typeName {
|
||||
case "text":
|
||||
if text, ok := value["text"].(string); ok {
|
||||
return estimateApproxTokens(text)
|
||||
}
|
||||
case "thinking":
|
||||
if thinking, ok := value["thinking"].(string); ok {
|
||||
return estimateApproxTokens(thinking)
|
||||
}
|
||||
case "tool_use":
|
||||
total := 0
|
||||
if name, ok := value["name"].(string); ok {
|
||||
total += estimateApproxTokens(name)
|
||||
}
|
||||
if input, ok := value["input"]; ok {
|
||||
total += estimateJSONTokens(input)
|
||||
}
|
||||
if total > 0 {
|
||||
return total
|
||||
}
|
||||
case "tool_result":
|
||||
if content, ok := value["content"]; ok {
|
||||
return estimateClaudeValueTokens(content)
|
||||
}
|
||||
}
|
||||
|
||||
total := 0
|
||||
if text, ok := value["text"].(string); ok {
|
||||
total += estimateApproxTokens(text)
|
||||
}
|
||||
if thinking, ok := value["thinking"].(string); ok {
|
||||
total += estimateApproxTokens(thinking)
|
||||
}
|
||||
if content, ok := value["content"]; ok {
|
||||
total += estimateClaudeValueTokens(content)
|
||||
}
|
||||
if total > 0 {
|
||||
return total
|
||||
}
|
||||
|
||||
return estimateJSONTokens(value)
|
||||
default:
|
||||
return estimateJSONTokens(value)
|
||||
}
|
||||
}
|
||||
|
||||
func estimateJSONTokens(v interface{}) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return estimateApproxTokens(string(b))
|
||||
}
|
||||
|
||||
func estimateOpenAIRequestInputTokens(req *OpenAIRequest) int {
|
||||
if req == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
total += estimateOpenAIContentTokens(msg.Content)
|
||||
total += estimateApproxTokens(msg.ToolCallID)
|
||||
for _, tc := range msg.ToolCalls {
|
||||
total += estimateApproxTokens(tc.Function.Name)
|
||||
total += estimateApproxTokens(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tool := range req.Tools {
|
||||
total += estimateApproxTokens(tool.Function.Name)
|
||||
total += estimateApproxTokens(tool.Function.Description)
|
||||
total += estimateJSONTokens(tool.Function.Parameters)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateOpenAIContentTokens(content interface{}) int {
|
||||
switch value := content.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case string:
|
||||
return estimateApproxTokens(value)
|
||||
default:
|
||||
text := extractOpenAIMessageText(value)
|
||||
if text != "" {
|
||||
return estimateApproxTokens(text)
|
||||
}
|
||||
return estimateJSONTokens(value)
|
||||
}
|
||||
}
|
||||
|
||||
func estimateOpenAIOutputTokens(content, reasoningContent string, toolUses []KiroToolUse) int {
|
||||
return estimateClaudeOutputTokens(content, reasoningContent, toolUses)
|
||||
}
|
||||
@@ -10,34 +10,40 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 模型映射
|
||||
var modelMap = map[string]string{
|
||||
"claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4.5": "claude-sonnet-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
"claude-haiku-4.5": "claude-haiku-4.5",
|
||||
"claude-sonnet-4-6": "claude-sonnet-4.6",
|
||||
"claude-sonnet-4.6": "claude-sonnet-4.6",
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4.6": "claude-opus-4.6",
|
||||
"claude-opus-4-5": "claude-opus-4.5",
|
||||
"claude-opus-4.5": "claude-opus-4.5",
|
||||
"claude-sonnet-4": "claude-sonnet-4",
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4",
|
||||
"claude-3-5-sonnet": "claude-sonnet-4.5",
|
||||
"claude-3-opus": "claude-sonnet-4.5",
|
||||
"claude-3-sonnet": "claude-sonnet-4",
|
||||
"claude-3-haiku": "claude-haiku-4.5",
|
||||
"gpt-4": "claude-sonnet-4.5",
|
||||
"gpt-4o": "claude-sonnet-4.5",
|
||||
"gpt-4-turbo": "claude-sonnet-4.5",
|
||||
"gpt-3.5-turbo": "claude-sonnet-4.5",
|
||||
type modelRule struct {
|
||||
pattern string
|
||||
target string
|
||||
}
|
||||
|
||||
var modelRules = []modelRule{
|
||||
{pattern: "claude-sonnet-4-20250514", target: "claude-sonnet-4"},
|
||||
{pattern: "claude-sonnet-4-6", target: "claude-sonnet-4.6"},
|
||||
{pattern: "claude-sonnet-4.6", target: "claude-sonnet-4.6"},
|
||||
{pattern: "claude-sonnet-4-5", target: "claude-sonnet-4.5"},
|
||||
{pattern: "claude-sonnet-4.5", target: "claude-sonnet-4.5"},
|
||||
{pattern: "claude-haiku-4-5", target: "claude-haiku-4.5"},
|
||||
{pattern: "claude-haiku-4.5", target: "claude-haiku-4.5"},
|
||||
{pattern: "claude-opus-4-6", target: "claude-opus-4.6"},
|
||||
{pattern: "claude-opus-4.6", target: "claude-opus-4.6"},
|
||||
{pattern: "claude-opus-4-5", target: "claude-opus-4.5"},
|
||||
{pattern: "claude-opus-4.5", target: "claude-opus-4.5"},
|
||||
{pattern: "claude-3-5-sonnet", target: "claude-sonnet-4.5"},
|
||||
{pattern: "claude-3-opus", target: "claude-sonnet-4.5"},
|
||||
{pattern: "claude-3-sonnet", target: "claude-sonnet-4"},
|
||||
{pattern: "claude-3-haiku", target: "claude-haiku-4.5"},
|
||||
{pattern: "gpt-4o", target: "claude-sonnet-4.5"},
|
||||
{pattern: "gpt-4-turbo", target: "claude-sonnet-4.5"},
|
||||
{pattern: "gpt-3.5-turbo", target: "claude-sonnet-4.5"},
|
||||
{pattern: "gpt-4", target: "claude-sonnet-4.5"},
|
||||
{pattern: "claude-sonnet-4", target: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
// Thinking 模式提示
|
||||
const ThinkingModePrompt = `<thinking_mode>enabled</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>`
|
||||
|
||||
const minimalFallbackUserContent = "."
|
||||
|
||||
// ParseModelAndThinking 解析模型名称,返回实际模型和是否启用 thinking
|
||||
func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) {
|
||||
lower := strings.ToLower(model)
|
||||
@@ -51,10 +57,9 @@ func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) {
|
||||
lower = strings.ToLower(model)
|
||||
}
|
||||
|
||||
// 映射模型
|
||||
for k, v := range modelMap {
|
||||
if strings.Contains(lower, k) {
|
||||
return v, thinking
|
||||
for _, rule := range modelRules {
|
||||
if strings.Contains(lower, rule.pattern) {
|
||||
return rule.target, thinking
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +98,7 @@ type ClaudeMessage struct {
|
||||
type ClaudeContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
@@ -145,24 +151,6 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
|
||||
systemPrompt = ThinkingModePrompt + "\n\n" + systemPrompt
|
||||
}
|
||||
|
||||
// 注入时间戳
|
||||
timestamp := time.Now().Format(time.RFC3339)
|
||||
systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt
|
||||
|
||||
// 注入执行导向指令(防止 AI 在探索过程中丢失目标)
|
||||
executionDirective := `
|
||||
<execution_discipline>
|
||||
当用户要求执行特定任务时,你必须遵循以下纪律:
|
||||
1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向
|
||||
2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析
|
||||
3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态
|
||||
4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题
|
||||
5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务
|
||||
6. **完整交付**:直到所有任务步骤都执行完毕才算完成
|
||||
</execution_discipline>
|
||||
`
|
||||
systemPrompt = systemPrompt + "\n\n" + executionDirective
|
||||
|
||||
// 构建历史消息
|
||||
history := make([]KiroHistoryMessage, 0)
|
||||
var currentContent string
|
||||
@@ -174,6 +162,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
|
||||
|
||||
if msg.Role == "user" {
|
||||
content, images, toolResults := extractClaudeUserContent(msg.Content)
|
||||
content = normalizeUserContent(content, len(images) > 0)
|
||||
|
||||
if isLast {
|
||||
currentContent = content
|
||||
@@ -226,10 +215,12 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
|
||||
}
|
||||
if currentContent != "" {
|
||||
finalContent += currentContent
|
||||
} else if len(currentImages) > 0 {
|
||||
finalContent += normalizeUserContent("", true)
|
||||
} else if len(currentToolResults) > 0 {
|
||||
finalContent += "Tool results provided."
|
||||
finalContent += buildToolResultsContinuation(currentToolResults)
|
||||
} else {
|
||||
finalContent += "Continue"
|
||||
finalContent += minimalFallbackUserContent
|
||||
}
|
||||
|
||||
// 转换工具
|
||||
@@ -238,7 +229,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
|
||||
// 构建 payload
|
||||
payload := &KiroPayload{}
|
||||
payload.ConversationState.ChatTriggerType = "MANUAL"
|
||||
payload.ConversationState.ConversationID = uuid.New().String()
|
||||
payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstClaudeConversationAnchor(req.Messages))
|
||||
payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{
|
||||
Content: finalContent,
|
||||
ModelID: modelID,
|
||||
@@ -307,24 +298,13 @@ func extractClaudeUserContent(content interface{}) (string, []KiroImage, []KiroT
|
||||
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
case "text", "input_text":
|
||||
if t, ok := block["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
case "image":
|
||||
if source, ok := block["source"].(map[string]interface{}); ok {
|
||||
mediaType, _ := source["media_type"].(string)
|
||||
data, _ := source["data"].(string)
|
||||
format := strings.TrimPrefix(mediaType, "image/")
|
||||
if format == "jpg" {
|
||||
format = "jpeg"
|
||||
}
|
||||
images = append(images, KiroImage{
|
||||
Format: format,
|
||||
Source: struct {
|
||||
Bytes string `json:"bytes"`
|
||||
}{Bytes: data},
|
||||
})
|
||||
case "image", "image_url", "input_image":
|
||||
if img := extractImageFromClaudeBlock(block); img != nil {
|
||||
images = append(images, *img)
|
||||
}
|
||||
case "tool_result":
|
||||
toolUseID, _ := block["tool_use_id"].(string)
|
||||
@@ -341,6 +321,44 @@ func extractClaudeUserContent(content interface{}) (string, []KiroImage, []KiroT
|
||||
return text, images, toolResults
|
||||
}
|
||||
|
||||
func extractImageFromClaudeBlock(block map[string]interface{}) *KiroImage {
|
||||
if source, ok := block["source"].(map[string]interface{}); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
if img := parseDataURL(data); img != nil {
|
||||
return img
|
||||
}
|
||||
mediaType, _ := source["media_type"].(string)
|
||||
if mediaType == "" {
|
||||
mediaType, _ = source["mediaType"].(string)
|
||||
}
|
||||
if mediaType == "" {
|
||||
mediaType, _ = source["mime_type"].(string)
|
||||
}
|
||||
format := strings.TrimPrefix(strings.ToLower(mediaType), "image/")
|
||||
if img := parseBase64Image(data, format); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
if url, ok := source["url"].(string); ok {
|
||||
if img := parseDataURL(url); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if img := extractImageFromOpenAIPart(block); img != nil {
|
||||
return img
|
||||
}
|
||||
|
||||
if data, ok := block["data"].(string); ok {
|
||||
if img := parseDataURL(data); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractToolResultContent(content interface{}) string {
|
||||
if s, ok := content.(string); ok {
|
||||
return s
|
||||
@@ -396,10 +414,6 @@ func extractClaudeAssistantContent(content interface{}) (string, []KiroToolUse)
|
||||
}
|
||||
}
|
||||
|
||||
if text == "" && len(toolUses) > 0 {
|
||||
text = "Using tools."
|
||||
}
|
||||
|
||||
return text, toolUses
|
||||
}
|
||||
|
||||
@@ -441,9 +455,16 @@ func shortenToolName(name string) string {
|
||||
|
||||
// ==================== Kiro -> Claude 转换 ====================
|
||||
|
||||
func KiroToClaudeResponse(content string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse {
|
||||
func KiroToClaudeResponse(content, thinkingContent string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse {
|
||||
blocks := make([]ClaudeContentBlock, 0)
|
||||
|
||||
if thinkingContent != "" {
|
||||
blocks = append(blocks, ClaudeContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: thinkingContent,
|
||||
})
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
blocks = append(blocks, ClaudeContentBlock{
|
||||
Type: "text",
|
||||
@@ -549,7 +570,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role == "system" {
|
||||
if s, ok := msg.Content.(string); ok {
|
||||
if s := extractOpenAIMessageText(msg.Content); s != "" {
|
||||
systemPrompt += s + "\n"
|
||||
}
|
||||
} else {
|
||||
@@ -562,24 +583,6 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
systemPrompt = ThinkingModePrompt + "\n\n" + systemPrompt
|
||||
}
|
||||
|
||||
// 注入时间戳
|
||||
timestamp := time.Now().Format(time.RFC3339)
|
||||
systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt
|
||||
|
||||
// 注入执行导向指令(防止 AI 在探索过程中丢失目标)
|
||||
executionDirective := `
|
||||
<execution_discipline>
|
||||
当用户要求执行特定任务时,你必须遵循以下纪律:
|
||||
1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向
|
||||
2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析
|
||||
3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态
|
||||
4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题
|
||||
5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务
|
||||
6. **完整交付**:直到所有任务步骤都执行完毕才算完成
|
||||
</execution_discipline>
|
||||
`
|
||||
systemPrompt = systemPrompt + "\n\n" + executionDirective
|
||||
|
||||
// 构建历史消息
|
||||
history := make([]KiroHistoryMessage, 0)
|
||||
var currentContent string
|
||||
@@ -593,6 +596,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
content, images := extractOpenAIUserContent(msg.Content)
|
||||
content = normalizeUserContent(content, len(images) > 0)
|
||||
|
||||
// 第一条 user 消息合并 system prompt
|
||||
if !systemMerged && systemPrompt != "" {
|
||||
@@ -615,10 +619,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
content, _ := msg.Content.(string)
|
||||
if content == "" && len(msg.ToolCalls) > 0 {
|
||||
content = "Using tools."
|
||||
}
|
||||
content := extractOpenAIMessageText(msg.Content)
|
||||
|
||||
var toolUses []KiroToolUse
|
||||
for _, tc := range msg.ToolCalls {
|
||||
@@ -642,7 +643,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
})
|
||||
|
||||
case "tool":
|
||||
content, _ := msg.Content.(string)
|
||||
content := extractOpenAIMessageText(msg.Content)
|
||||
currentToolResults = append(currentToolResults, KiroToolResult{
|
||||
ToolUseID: msg.ToolCallID,
|
||||
Content: []KiroResultContent{{Text: content}},
|
||||
@@ -655,7 +656,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
if !isLast {
|
||||
history = append(history, KiroHistoryMessage{
|
||||
UserInputMessage: &KiroUserInputMessage{
|
||||
Content: "Tool results provided.",
|
||||
Content: buildToolResultsContinuation(currentToolResults),
|
||||
ModelID: modelID,
|
||||
Origin: origin,
|
||||
UserInputMessageContext: &UserInputMessageContext{
|
||||
@@ -672,10 +673,12 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
// 构建最终内容
|
||||
finalContent := currentContent
|
||||
if finalContent == "" {
|
||||
if len(currentToolResults) > 0 {
|
||||
finalContent = "Tool results provided."
|
||||
if len(currentImages) > 0 {
|
||||
finalContent = normalizeUserContent("", true)
|
||||
} else if len(currentToolResults) > 0 {
|
||||
finalContent = buildToolResultsContinuation(currentToolResults)
|
||||
} else {
|
||||
finalContent = "Continue"
|
||||
finalContent = minimalFallbackUserContent
|
||||
}
|
||||
}
|
||||
if !systemMerged && systemPrompt != "" {
|
||||
@@ -688,7 +691,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
|
||||
// 构建 payload
|
||||
payload := &KiroPayload{}
|
||||
payload.ConversationState.ChatTriggerType = "MANUAL"
|
||||
payload.ConversationState.ConversationID = uuid.New().String()
|
||||
payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstOpenAIConversationAnchor(nonSystemMessages))
|
||||
payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{
|
||||
Content: finalContent,
|
||||
ModelID: modelID,
|
||||
@@ -726,6 +729,15 @@ func extractOpenAIUserContent(content interface{}) (string, []KiroImage) {
|
||||
var text string
|
||||
var images []KiroImage
|
||||
|
||||
if part, ok := content.(map[string]interface{}); ok {
|
||||
if t, ok := extractOpenAITextPart(part); ok {
|
||||
text += t
|
||||
}
|
||||
if img := extractImageFromOpenAIPart(part); img != nil {
|
||||
images = append(images, *img)
|
||||
}
|
||||
}
|
||||
|
||||
if parts, ok := content.([]interface{}); ok {
|
||||
for _, p := range parts {
|
||||
part, ok := p.(map[string]interface{})
|
||||
@@ -733,50 +745,301 @@ func extractOpenAIUserContent(content interface{}) (string, []KiroImage) {
|
||||
continue
|
||||
}
|
||||
|
||||
partType, _ := part["type"].(string)
|
||||
switch partType {
|
||||
case "text":
|
||||
if t, ok := part["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
case "image_url":
|
||||
if imgUrl, ok := part["image_url"].(map[string]interface{}); ok {
|
||||
if url, ok := imgUrl["url"].(string); ok {
|
||||
if img := parseDataURL(url); img != nil {
|
||||
images = append(images, *img)
|
||||
}
|
||||
}
|
||||
}
|
||||
if t, ok := extractOpenAITextPart(part); ok {
|
||||
text += t
|
||||
}
|
||||
if img := extractImageFromOpenAIPart(part); img != nil {
|
||||
images = append(images, *img)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(images) > 0 {
|
||||
text = sanitizeImagePlaceholders(text)
|
||||
}
|
||||
|
||||
return text, images
|
||||
}
|
||||
|
||||
func extractOpenAIMessageText(content interface{}) string {
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if s, ok := content.(string); ok {
|
||||
return s
|
||||
}
|
||||
|
||||
if text, _ := extractOpenAIUserContent(content); strings.TrimSpace(text) != "" {
|
||||
return text
|
||||
}
|
||||
|
||||
switch v := content.(type) {
|
||||
case map[string]interface{}:
|
||||
if nested, ok := v["content"]; ok {
|
||||
if nestedText := extractOpenAIMessageText(nested); strings.TrimSpace(nestedText) != "" {
|
||||
return nestedText
|
||||
}
|
||||
}
|
||||
if raw, err := json.Marshal(v); err == nil {
|
||||
return string(raw)
|
||||
}
|
||||
case []interface{}:
|
||||
parts := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
partText := extractOpenAIMessageText(item)
|
||||
if strings.TrimSpace(partText) != "" {
|
||||
parts = append(parts, partText)
|
||||
}
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
if raw, err := json.Marshal(v); err == nil {
|
||||
return string(raw)
|
||||
}
|
||||
default:
|
||||
if raw, err := json.Marshal(v); err == nil {
|
||||
return string(raw)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildToolResultsContinuation(toolResults []KiroToolResult) string {
|
||||
if len(toolResults) == 0 {
|
||||
return minimalFallbackUserContent
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if len(tr.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, c := range tr.Content {
|
||||
if strings.TrimSpace(c.Text) != "" {
|
||||
parts = append(parts, c.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return minimalFallbackUserContent
|
||||
}
|
||||
|
||||
joined := strings.Join(parts, "\n\n")
|
||||
if len(joined) > 4000 {
|
||||
return joined[:4000]
|
||||
}
|
||||
return joined
|
||||
}
|
||||
|
||||
func firstClaudeConversationAnchor(messages []ClaudeMessage) string {
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "user" {
|
||||
continue
|
||||
}
|
||||
text, _, toolResults := extractClaudeUserContent(msg.Content)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
if len(toolResults) > 0 {
|
||||
return buildToolResultsContinuation(toolResults)
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if strings.TrimSpace(msg.Role) != "" {
|
||||
if text := extractOpenAIMessageText(msg.Content); strings.TrimSpace(text) != "" {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstOpenAIConversationAnchor(messages []OpenAIMessage) string {
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "user" {
|
||||
continue
|
||||
}
|
||||
text := extractOpenAIMessageText(msg.Content)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
text := extractOpenAIMessageText(msg.Content)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildConversationID(modelID, systemPrompt, anchor string) string {
|
||||
anchor = strings.TrimSpace(anchor)
|
||||
if anchor == "" {
|
||||
return uuid.New().String()
|
||||
}
|
||||
seed := strings.Join([]string{modelID, strings.TrimSpace(systemPrompt), anchor}, "\n")
|
||||
return uuid.NewSHA1(uuid.NameSpaceURL, []byte(seed)).String()
|
||||
}
|
||||
|
||||
// extractOpenAITextPart returns the "text" field of a content part.
//
// The second result reports whether the part has a string-valued "text"
// entry. The part's "type" is deliberately not consulted: any part carrying a
// string "text" is treated as text, whether it is tagged "text", "input_text",
// something else, or nothing at all.
func extractOpenAITextPart(part map[string]interface{}) (string, bool) {
	text, ok := part["text"].(string)
	return text, ok
}
|
||||
|
||||
func extractImageFromOpenAIPart(part map[string]interface{}) *KiroImage {
|
||||
partType, _ := part["type"].(string)
|
||||
if partType != "" {
|
||||
switch partType {
|
||||
case "image", "image_url", "input_image", "file", "input_file":
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if fileObj, ok := part["file"].(map[string]interface{}); ok {
|
||||
if img := extractImageFromOpenAIPart(fileObj); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
if sourceObj, ok := part["source"].(map[string]interface{}); ok {
|
||||
if img := extractImageFromOpenAIPart(sourceObj); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := part["mime"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
|
||||
return nil
|
||||
}
|
||||
if raw, ok := part["media_type"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
|
||||
return nil
|
||||
}
|
||||
if raw, ok := part["mime_type"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
|
||||
return nil
|
||||
}
|
||||
|
||||
if raw, ok := part["url"].(string); ok {
|
||||
if img := parseDataURL(raw); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := part["b64_json"].(string); ok {
|
||||
if img := parseBase64Image(raw, "png"); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := part["image_url"]; ok {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if img := parseDataURL(v); img != nil {
|
||||
return img
|
||||
}
|
||||
case map[string]interface{}:
|
||||
if u, ok := v["url"].(string); ok {
|
||||
if img := parseDataURL(u); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := part["image_base64"].(string); ok {
|
||||
if img := parseBase64Image(raw, "png"); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
if raw, ok := part["data"].(string); ok {
|
||||
if img := parseDataURL(raw); img != nil {
|
||||
return img
|
||||
}
|
||||
if img := parseBase64Image(raw, "png"); img != nil {
|
||||
return img
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// imagePlaceholderPattern matches textual image markers such as "[Image 1]".
// It is compiled once at package scope rather than on every call.
var imagePlaceholderPattern = regexp.MustCompile(`\[Image\s+\d+\]`)

// sanitizeImagePlaceholders removes "[Image N]" markers from text and
// collapses every run of whitespace (including newlines) into a single space.
// The result has no leading or trailing whitespace.
func sanitizeImagePlaceholders(text string) string {
	stripped := imagePlaceholderPattern.ReplaceAllString(text, "")
	// Fields drops leading/trailing whitespace, so no extra trim is needed.
	return strings.Join(strings.Fields(stripped), " ")
}
|
||||
|
||||
// normalizeUserContent trims surrounding whitespace from user text. When the
// trimmed text is empty and the message has images, a default instruction is
// returned instead so that the message is never empty.
func normalizeUserContent(text string, hasImages bool) string {
	if body := strings.TrimSpace(text); body != "" || !hasImages {
		return body
	}
	return "Please analyze the attached image."
}
|
||||
|
||||
func parseDataURL(url string) *KiroImage {
|
||||
// data:image/png;base64,xxxxx
|
||||
re := regexp.MustCompile(`^data:image/(\w+);base64,(.+)$`)
|
||||
matches := re.FindStringSubmatch(url)
|
||||
cleaned := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(url, "\n", ""), "\r", ""))
|
||||
if strings.Contains(cleaned, "[Image") {
|
||||
return nil
|
||||
}
|
||||
re := regexp.MustCompile(`^data:image/([a-zA-Z0-9+.-]+)(;[a-zA-Z0-9=._:+-]+)*;base64,(.+)$`)
|
||||
matches := re.FindStringSubmatch(cleaned)
|
||||
if len(matches) == 4 {
|
||||
return parseBase64Image(matches[3], matches[1])
|
||||
}
|
||||
if len(matches) != 3 {
|
||||
return nil
|
||||
}
|
||||
|
||||
format := matches[1]
|
||||
return parseBase64Image(matches[2], matches[1])
|
||||
}
|
||||
|
||||
func parseBase64Image(data, format string) *KiroImage {
|
||||
format = strings.ToLower(format)
|
||||
if format == "jpg" {
|
||||
format = "jpeg"
|
||||
}
|
||||
|
||||
// 验证 base64
|
||||
if _, err := base64.StdEncoding.DecodeString(matches[2]); err != nil {
|
||||
return nil
|
||||
if _, err := base64.StdEncoding.DecodeString(data); err != nil {
|
||||
if _, errRaw := base64.RawStdEncoding.DecodeString(data); errRaw != nil {
|
||||
if _, errURL := base64.URLEncoding.DecodeString(data); errURL != nil {
|
||||
if _, errRawURL := base64.RawURLEncoding.DecodeString(data); errRawURL != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if format == "" {
|
||||
format = "png"
|
||||
}
|
||||
|
||||
return &KiroImage{
|
||||
Format: format,
|
||||
Source: struct {
|
||||
Bytes string `json:"bytes"`
|
||||
}{Bytes: matches[2]},
|
||||
}{Bytes: data},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
198
proxy/translator_test.go
Normal file
198
proxy/translator_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractOpenAIMessageTextStructured(t *testing.T) {
|
||||
content := []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "alpha"},
|
||||
map[string]interface{}{"type": "input_text", "text": "beta"},
|
||||
}
|
||||
|
||||
if got := extractOpenAIMessageText(content); got != "alphabeta" {
|
||||
t.Fatalf("expected concatenated structured text, got %q", got)
|
||||
}
|
||||
|
||||
nested := map[string]interface{}{
|
||||
"content": []interface{}{map[string]interface{}{"type": "text", "text": "nested"}},
|
||||
}
|
||||
if got := extractOpenAIMessageText(nested); got != "nested" {
|
||||
t.Fatalf("expected nested content extraction, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToKiroPreservesStructuredAssistantAndToolContent(t *testing.T) {
|
||||
req := &OpenAIRequest{
|
||||
Model: "claude-sonnet-4.5",
|
||||
Messages: []OpenAIMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "system-a"},
|
||||
map[string]interface{}{"type": "text", "text": "system-b"},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "first-question"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "assistant-structured"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "tool-result-structured"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
payload := OpenAIToKiro(req, false)
|
||||
|
||||
if len(payload.ConversationState.History) != 2 {
|
||||
t.Fatalf("expected 2 history items, got %d", len(payload.ConversationState.History))
|
||||
}
|
||||
|
||||
firstHistoryUser := payload.ConversationState.History[0].UserInputMessage
|
||||
if firstHistoryUser == nil {
|
||||
t.Fatalf("expected first history item to be user message")
|
||||
}
|
||||
if !strings.Contains(firstHistoryUser.Content, "system-a") ||
|
||||
!strings.Contains(firstHistoryUser.Content, "system-b") ||
|
||||
!strings.Contains(firstHistoryUser.Content, "first-question") {
|
||||
t.Fatalf("expected merged system+user content, got %q", firstHistoryUser.Content)
|
||||
}
|
||||
|
||||
historyAssistant := payload.ConversationState.History[1].AssistantResponseMessage
|
||||
if historyAssistant == nil {
|
||||
t.Fatalf("expected second history item to be assistant message")
|
||||
}
|
||||
if historyAssistant.Content != "assistant-structured" {
|
||||
t.Fatalf("expected assistant structured content to be preserved, got %q", historyAssistant.Content)
|
||||
}
|
||||
|
||||
cur := payload.ConversationState.CurrentMessage.UserInputMessage
|
||||
if cur.Content != "tool-result-structured" {
|
||||
t.Fatalf("expected tool-result continuation content, got %q", cur.Content)
|
||||
}
|
||||
if cur.UserInputMessageContext == nil || len(cur.UserInputMessageContext.ToolResults) != 1 {
|
||||
t.Fatalf("expected one tool result in current context")
|
||||
}
|
||||
gotToolText := cur.UserInputMessageContext.ToolResults[0].Content[0].Text
|
||||
if gotToolText != "tool-result-structured" {
|
||||
t.Fatalf("expected structured tool result text, got %q", gotToolText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToKiroAssistantMapContentInHistory(t *testing.T) {
|
||||
req := &OpenAIRequest{
|
||||
Model: "claude-sonnet-4.5",
|
||||
Messages: []OpenAIMessage{
|
||||
{Role: "user", Content: "u1"},
|
||||
{Role: "assistant", Content: map[string]interface{}{"type": "text", "text": "assistant-map"}},
|
||||
{Role: "user", Content: "u2"},
|
||||
},
|
||||
}
|
||||
|
||||
payload := OpenAIToKiro(req, false)
|
||||
|
||||
if len(payload.ConversationState.History) != 2 {
|
||||
t.Fatalf("expected 2 history entries, got %d", len(payload.ConversationState.History))
|
||||
}
|
||||
assistant := payload.ConversationState.History[1].AssistantResponseMessage
|
||||
if assistant == nil {
|
||||
t.Fatalf("expected second history entry to be assistant")
|
||||
}
|
||||
if assistant.Content != "assistant-map" {
|
||||
t.Fatalf("expected assistant map content preserved, got %q", assistant.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToKiroAssistantToolCallsDoNotInjectPlaceholder(t *testing.T) {
|
||||
req := &OpenAIRequest{
|
||||
Model: "claude-sonnet-4.5",
|
||||
Messages: []OpenAIMessage{
|
||||
{Role: "user", Content: "find weather"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: nil,
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{Name: "get_weather", Arguments: "{}"},
|
||||
}},
|
||||
},
|
||||
{Role: "user", Content: "continue"},
|
||||
},
|
||||
}
|
||||
|
||||
payload := OpenAIToKiro(req, false)
|
||||
if len(payload.ConversationState.History) < 2 {
|
||||
t.Fatalf("expected history with assistant tool call")
|
||||
}
|
||||
assistant := payload.ConversationState.History[1].AssistantResponseMessage
|
||||
if assistant == nil {
|
||||
t.Fatalf("expected assistant history entry")
|
||||
}
|
||||
if assistant.Content != "" {
|
||||
t.Fatalf("expected empty assistant content for tool-call-only turn, got %q", assistant.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIConversationIDStableFromAnchor(t *testing.T) {
|
||||
baseMessages := []OpenAIMessage{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Build calculator"},
|
||||
{Role: "assistant", Content: "Sure"},
|
||||
{Role: "user", Content: "Continue"},
|
||||
}
|
||||
|
||||
reqA := &OpenAIRequest{Model: "claude-sonnet-4.5", Messages: baseMessages}
|
||||
reqB := &OpenAIRequest{Model: "claude-sonnet-4.5", Messages: append(baseMessages, OpenAIMessage{Role: "assistant", Content: "Next step"})}
|
||||
|
||||
payloadA := OpenAIToKiro(reqA, false)
|
||||
payloadB := OpenAIToKiro(reqB, false)
|
||||
|
||||
if payloadA.ConversationState.ConversationID == "" || payloadB.ConversationState.ConversationID == "" {
|
||||
t.Fatalf("expected non-empty conversation IDs")
|
||||
}
|
||||
if payloadA.ConversationState.ConversationID != payloadB.ConversationState.ConversationID {
|
||||
t.Fatalf("expected stable conversation ID across turns, got %q vs %q", payloadA.ConversationState.ConversationID, payloadB.ConversationState.ConversationID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeConversationIDStableFromAnchor(t *testing.T) {
|
||||
reqA := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4.5",
|
||||
System: "sys",
|
||||
Messages: []ClaudeMessage{
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
}
|
||||
reqB := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4.5",
|
||||
System: "sys",
|
||||
Messages: []ClaudeMessage{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "ok"},
|
||||
{Role: "user", Content: "next"},
|
||||
},
|
||||
}
|
||||
|
||||
payloadA := ClaudeToKiro(reqA, false)
|
||||
payloadB := ClaudeToKiro(reqB, false)
|
||||
|
||||
if payloadA.ConversationState.ConversationID == "" || payloadB.ConversationState.ConversationID == "" {
|
||||
t.Fatalf("expected non-empty conversation IDs")
|
||||
}
|
||||
if payloadA.ConversationState.ConversationID != payloadB.ConversationState.ConversationID {
|
||||
t.Fatalf("expected stable conversation ID across turns, got %q vs %q", payloadA.ConversationState.ConversationID, payloadB.ConversationState.ConversationID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user