diff --git a/proxy/handler.go b/proxy/handler.go
index 7a70b08..2559406 100644
--- a/proxy/handler.go
+++ b/proxy/handler.go
@@ -35,6 +35,32 @@ type Handler struct {
modelsCacheTime int64
}
+// thinkingStreamSource records which upstream channel supplies thinking output
+// for a single streamed response. The stream handlers keep one value per
+// response and pass its address to allowReasoningSource / allowTagSource, which
+// make the two channels mutually exclusive (presumably so the same reasoning is
+// not emitted twice when it arrives both ways).
+type thinkingStreamSource int
+
+const (
+ // thinkingSourceUnknown is the zero value: no channel has been selected yet.
+ thinkingSourceUnknown thinkingStreamSource = iota
+ // thinkingSourceReasoningEvent means thinking text is taken from dedicated
+ // reasoning events (the isThinking path of the stream callbacks).
+ thinkingSourceReasoningEvent
+ // thinkingSourceTagBlock means thinking text is taken from tag-delimited
+ // blocks found inside the ordinary text stream.
+ thinkingSourceTagBlock
+)
+
+// allowReasoningSource reports whether thinking text delivered through a
+// reasoning event may be forwarded. It refuses once the tag-block channel has
+// been selected; in every other state it records the reasoning-event channel
+// in *source and accepts.
+func allowReasoningSource(source *thinkingStreamSource) bool {
+	switch *source {
+	case thinkingSourceTagBlock:
+		// The tag-block channel already owns this response.
+		return false
+	default:
+		*source = thinkingSourceReasoningEvent
+		return true
+	}
+}
+
+// allowTagSource reports whether thinking text found in a tag block may be
+// forwarded. It refuses once the reasoning-event channel has been selected,
+// claims the tag-block channel when nothing has been selected yet, and
+// otherwise accepts only if the tag-block channel is the current selection.
+func allowTagSource(source *thinkingStreamSource) bool {
+	switch *source {
+	case thinkingSourceReasoningEvent:
+		// The reasoning-event channel already owns this response.
+		return false
+	case thinkingSourceUnknown:
+		// First thinking output seen: claim the channel for tag blocks.
+		*source = thinkingSourceTagBlock
+	}
+	return *source == thinkingSourceTagBlock
+}
+
func NewHandler() *Handler {
totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats()
h := &Handler{
@@ -248,36 +274,33 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
var models []map[string]interface{}
if len(cached) > 0 {
for _, m := range cached {
- models = append(models, map[string]interface{}{
- "id": m.ModelId, "object": "model", "owned_by": "anthropic",
- })
+ supportsImage := modelSupportsImage(m.InputTypes)
+ models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage))
// 自动生成 thinking 变体
- models = append(models, map[string]interface{}{
- "id": m.ModelId + thinkingSuffix, "object": "model", "owned_by": "anthropic",
- })
+ models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage))
}
} else {
// fallback 静态列表
models = []map[string]interface{}{
- {"id": "claude-sonnet-4.6", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-sonnet-4.6" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
- {"id": "claude-opus-4.6", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-opus-4.6" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
- {"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-sonnet-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
- {"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-sonnet-4" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
- {"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-haiku-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
- {"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
- {"id": "claude-opus-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
+ buildModelInfo("claude-sonnet-4.6", "anthropic", true),
+ buildModelInfo("claude-sonnet-4.6"+thinkingSuffix, "anthropic", true),
+ buildModelInfo("claude-opus-4.6", "anthropic", true),
+ buildModelInfo("claude-opus-4.6"+thinkingSuffix, "anthropic", true),
+ buildModelInfo("claude-sonnet-4.5", "anthropic", true),
+ buildModelInfo("claude-sonnet-4.5"+thinkingSuffix, "anthropic", true),
+ buildModelInfo("claude-sonnet-4", "anthropic", true),
+ buildModelInfo("claude-sonnet-4"+thinkingSuffix, "anthropic", true),
+ buildModelInfo("claude-haiku-4.5", "anthropic", true),
+ buildModelInfo("claude-haiku-4.5"+thinkingSuffix, "anthropic", true),
+ buildModelInfo("claude-opus-4.5", "anthropic", true),
+ buildModelInfo("claude-opus-4.5"+thinkingSuffix, "anthropic", true),
}
}
// 添加别名模型
models = append(models,
- map[string]interface{}{"id": "auto", "object": "model", "owned_by": "kiro-proxy"},
- map[string]interface{}{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
- map[string]interface{}{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
+ buildModelInfo("auto", "kiro-proxy", true),
+ buildModelInfo("gpt-4o", "kiro-proxy", true),
+ buildModelInfo("gpt-4", "kiro-proxy", true),
)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
@@ -287,6 +310,49 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
})
}
+// modelSupportsImage reports whether any entry of inputTypes mentions "image"
+// or "vision", compared case-insensitively as a substring. A nil or empty
+// slice yields false.
+func modelSupportsImage(inputTypes []string) bool {
+	for idx := range inputTypes {
+		lowered := strings.ToLower(inputTypes[idx])
+		mentionsImage := strings.Contains(lowered, "image")
+		if mentionsImage || strings.Contains(lowered, "vision") {
+			return true
+		}
+	}
+	return false
+}
+
+// buildModelInfo assembles one entry of the model list response.
+//
+// id and ownedBy are copied into "id" and "owned_by"; "object" is always
+// "model". supportsImage drives every image-related field: the input
+// modality list gains "image" when it is true, and the same flag is repeated
+// under "supports_image", "capabilities" and "info.meta.capabilities"
+// (presumably to satisfy several client conventions; confirm against consumers).
+// Output modalities are always just "text".
+func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface{} {
+	inputKinds := make([]string, 0, 2)
+	inputKinds = append(inputKinds, "text")
+	if supportsImage {
+		inputKinds = append(inputKinds, "image")
+	}
+
+	topLevelCaps := map[string]bool{
+		"vision":       supportsImage,
+		"image":        supportsImage,
+		"image_vision": supportsImage,
+	}
+	metaCaps := map[string]bool{
+		"vision":       supportsImage,
+		"image_vision": supportsImage,
+	}
+
+	entry := make(map[string]interface{}, 8)
+	entry["id"] = id
+	entry["object"] = "model"
+	entry["owned_by"] = ownedBy
+	entry["supports_image"] = supportsImage
+	// The same slice backs both the flat list and the nested "input" list.
+	entry["input_modalities"] = inputKinds
+	entry["modalities"] = map[string][]string{
+		"input":  inputKinds,
+		"output": []string{"text"},
+	}
+	entry["capabilities"] = topLevelCaps
+	entry["info"] = map[string]interface{}{
+		"meta": map[string]interface{}{
+			"capabilities": metaCaps,
+		},
+	}
+	return entry
+}
+
// refreshModelsCache 从 Kiro API 拉取模型列表并缓存
func (h *Handler) refreshModelsCache() {
account := h.pool.GetNext()
@@ -327,50 +393,13 @@ func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
return
}
- var req struct {
- Messages []struct {
- Role string `json:"role"`
- Content interface{} `json:"content"`
- } `json:"messages"`
- System interface{} `json:"system"`
- }
+ var req ClaudeRequest
if err := json.Unmarshal(body, &req); err != nil {
h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON")
return
}
- // 简单估算 token 数量(每 4 个字符约 1 个 token)
- var totalChars int
- for _, msg := range req.Messages {
- switch content := msg.Content.(type) {
- case string:
- totalChars += len(content)
- case []interface{}:
- for _, part := range content {
- if p, ok := part.(map[string]interface{}); ok {
- if text, ok := p["text"].(string); ok {
- totalChars += len(text)
- }
- }
- }
- }
- }
-
- // 系统提示
- switch system := req.System.(type) {
- case string:
- totalChars += len(system)
- case []interface{}:
- for _, part := range system {
- if p, ok := part.(map[string]interface{}); ok {
- if text, ok := p["text"].(string); ok {
- totalChars += len(text)
- }
- }
- }
- }
-
- estimatedTokens := (totalChars + 3) / 4 // 向上取整
+ estimatedTokens := estimateClaudeRequestInputTokens(&req)
if estimatedTokens < 1 {
estimatedTokens = 1
}
@@ -381,6 +410,10 @@ func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
// handleClaudeMessages Claude API 处理
func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
+ h.handleClaudeMessagesInternal(w, r)
+}
+
+func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method Not Allowed", 405)
return
@@ -416,20 +449,21 @@ func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) {
thinkingCfg := config.GetThinkingConfig()
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
req.Model = actualModel
+ estimatedInputTokens := estimateClaudeRequestInputTokens(&req)
// 转换请求
kiroPayload := ClaudeToKiro(&req, thinking)
// 流式或非流式
if req.Stream {
- h.handleClaudeStream(w, account, kiroPayload, req.Model)
+ h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
} else {
- h.handleClaudeNonStream(w, account, kiroPayload, req.Model)
+ h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
}
}
// handleClaudeStream Claude 流式响应
-func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
+func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
@@ -444,91 +478,169 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
msgID := "msg_" + uuid.New().String()
- var contentStarted bool
- var toolUseIndex int
var inputTokens, outputTokens int
var credits float64
var toolUses []KiroToolUse
+ var nextContentIndex int
+ var rawContentBuilder strings.Builder
+ var rawThinkingBuilder strings.Builder
+ activeBlockIndex := -1
+ activeBlockType := ""
+ startInputTokens := estimatedInputTokens
+
+ closeActiveBlock := func() {
+ if activeBlockIndex < 0 {
+ return
+ }
+ h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
+ "type": "content_block_stop",
+ "index": activeBlockIndex,
+ })
+ activeBlockIndex = -1
+ activeBlockType = ""
+ }
+
+ startContentBlock := func(blockType string) {
+ if activeBlockType == blockType {
+ return
+ }
+ closeActiveBlock()
+
+ idx := nextContentIndex
+ nextContentIndex++
+
+ if blockType == "thinking" {
+ h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
+ "type": "content_block_start",
+ "index": idx,
+ "content_block": map[string]string{
+ "type": "thinking",
+ "thinking": "",
+ },
+ })
+ } else {
+ h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
+ "type": "content_block_start",
+ "index": idx,
+ "content_block": map[string]string{
+ "type": "text",
+ "text": "",
+ },
+ })
+ }
+
+ activeBlockIndex = idx
+ activeBlockType = blockType
+ }
// Thinking 标签解析状态
var textBuffer string
var inThinkingBlock bool
+ var dropTagThinking bool
+ var thinkingSource thinkingStreamSource
// 发送文本的辅助函数
// thinkingState: 0=普通内容, 1=thinking开始, 2=thinking中间, 3=thinking结束
sendText := func(text string, thinkingState int) {
- // 确保 content_block 已开始
- if !contentStarted {
- h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
- "type": "content_block_start",
- "index": 0,
- "content_block": map[string]string{"type": "text", "text": ""},
- })
- contentStarted = true
- }
-
if thinkingState == 0 {
// 普通内容
if text == "" {
return
}
+ startContentBlock("text")
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta",
- "index": 0,
+ "index": activeBlockIndex,
"delta": map[string]string{"type": "text_delta", "text": text},
})
- } else {
- // thinking 内容
+ return
+ }
+
+ if !thinking {
+ return
+ }
+
+ switch thinkingFormat {
+ case "think":
var outputText string
- switch thinkingFormat {
- case "think":
- switch thinkingState {
- case 1:
- outputText = "" + text
- case 2:
- outputText = text
- case 3:
- outputText = text + ""
- }
- case "reasoning_content":
- // Claude 格式不支持 reasoning_content,直接输出内容
+ switch thinkingState {
+ case 1:
+ outputText = "" + text
+ case 2:
outputText = text
- default: // "thinking"
- switch thinkingState {
- case 1:
- outputText = "" + text
- case 2:
- outputText = text
- case 3:
- outputText = text + ""
- }
+ case 3:
+ outputText = text + ""
}
if outputText == "" {
return
}
+ startContentBlock("text")
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta",
- "index": 0,
+ "index": activeBlockIndex,
"delta": map[string]string{"type": "text_delta", "text": outputText},
})
+ case "reasoning_content":
+ if text == "" {
+ return
+ }
+ startContentBlock("text")
+ h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
+ "type": "content_block_delta",
+ "index": activeBlockIndex,
+ "delta": map[string]string{"type": "text_delta", "text": text},
+ })
+ default:
+ if thinkingState == 3 && text == "" {
+ if activeBlockType == "thinking" {
+ closeActiveBlock()
+ }
+ return
+ }
+ if text != "" {
+ startContentBlock("thinking")
+ h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
+ "type": "content_block_delta",
+ "index": activeBlockIndex,
+ "delta": map[string]string{"type": "thinking_delta", "thinking": text},
+ })
+ }
+ if thinkingState == 3 && activeBlockType == "thinking" {
+ closeActiveBlock()
+ }
}
}
// 处理文本,解析 标签
var thinkingStarted bool
+ var eventThinkingOpen bool
processClaudeText := func(text string, isThinking bool, forceFlush bool) {
+ if isThinking && !thinking {
+ return
+ }
+
// 如果是 reasoningContentEvent,直接输出
if isThinking {
+ if !allowReasoningSource(&thinkingSource) {
+ return
+ }
if !thinkingStarted {
sendText(text, 1)
thinkingStarted = true
+ eventThinkingOpen = true
} else {
sendText(text, 2)
}
return
}
+ if eventThinkingOpen {
+ sendText("", 3)
+ eventThinkingOpen = false
+ thinkingStarted = false
+ }
+
textBuffer += text
for {
@@ -540,6 +652,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
}
textBuffer = textBuffer[thinkingStart+10:]
inThinkingBlock = true
+ dropTagThinking = !allowTagSource(&thinkingSource)
thinkingStarted = false
} else if forceFlush || len([]rune(textBuffer)) > 50 {
// 使用 rune 切片来正确处理 Unicode 字符
@@ -560,25 +673,33 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
thinkingEnd := strings.Index(textBuffer, "")
if thinkingEnd != -1 {
content := textBuffer[:thinkingEnd]
- if !thinkingStarted {
- sendText(content, 1)
- sendText("", 3)
- } else {
- sendText(content, 3)
+ if !dropTagThinking {
+ if !thinkingStarted {
+ sendText(content, 1)
+ sendText("", 3)
+ } else {
+ sendText(content, 3)
+ }
}
textBuffer = textBuffer[thinkingEnd+11:]
inThinkingBlock = false
+ dropTagThinking = false
thinkingStarted = false
} else if forceFlush {
if textBuffer != "" {
- if !thinkingStarted {
- sendText(textBuffer, 1)
- sendText("", 3)
- } else {
- sendText(textBuffer, 3)
+ if !dropTagThinking {
+ if !thinkingStarted {
+ sendText(textBuffer, 1)
+ sendText("", 3)
+ } else {
+ sendText(textBuffer, 3)
+ }
}
textBuffer = ""
}
+ inThinkingBlock = false
+ dropTagThinking = false
+ thinkingStarted = false
break
} else {
// 流式输出 thinking 块内的内容
@@ -586,11 +707,13 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
if len(runes) > 20 {
safeLen := len(runes) - 15
if safeLen > 0 {
- if !thinkingStarted {
- sendText(string(runes[:safeLen]), 1)
- thinkingStarted = true
- } else {
- sendText(string(runes[:safeLen]), 2)
+ if !dropTagThinking {
+ if !thinkingStarted {
+ sendText(string(runes[:safeLen]), 1)
+ thinkingStarted = true
+ } else {
+ sendText(string(runes[:safeLen]), 2)
+ }
}
textBuffer = string(runes[safeLen:])
}
@@ -605,11 +728,17 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
h.sendSSE(w, flusher, "message_start", map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
- "id": msgID,
- "type": "message",
- "role": "assistant",
- "content": []interface{}{},
- "model": model,
+ "id": msgID,
+ "type": "message",
+ "role": "assistant",
+ "content": []interface{}{},
+ "model": model,
+ "stop_reason": nil,
+ "stop_sequence": nil,
+ "usage": map[string]int{
+ "input_tokens": startInputTokens,
+ "output_tokens": 0,
+ },
},
})
@@ -618,27 +747,26 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
if text == "" {
return
}
+ if isThinking {
+ rawThinkingBuilder.WriteString(text)
+ } else {
+ rawContentBuilder.WriteString(text)
+ }
processClaudeText(text, isThinking, false)
},
OnToolUse: func(tu KiroToolUse) {
// 先刷新缓冲区
processClaudeText("", false, true)
+ rawContentBuilder.WriteString(tu.Name)
+ if b, err := json.Marshal(tu.Input); err == nil {
+ rawContentBuilder.Write(b)
+ }
toolUses = append(toolUses, tu)
+ closeActiveBlock()
- // 关闭文本块
- if contentStarted && toolUseIndex == 0 {
- h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
- "type": "content_block_stop",
- "index": 0,
- })
- }
-
- idx := toolUseIndex
- if contentStarted {
- idx = toolUseIndex + 1
- }
- toolUseIndex++
+ idx := nextContentIndex
+ nextContentIndex++
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start",
@@ -691,19 +819,27 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
// 刷新剩余缓冲区
processClaudeText("", false, true)
+ if eventThinkingOpen {
+ sendText("", 3)
+ eventThinkingOpen = false
+ }
+ closeActiveBlock()
+
+ inputTokens = estimatedInputTokens
+ outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
+ thinkingOutput := rawThinkingBuilder.String()
+ if thinking && thinkingOutput == "" && extractedReasoning != "" {
+ thinkingOutput = extractedReasoning
+ }
+ if !thinking {
+ thinkingOutput = ""
+ }
+ outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)
h.recordSuccess(inputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
- // 关闭最后的内容块
- if contentStarted && toolUseIndex == 0 {
- h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
- "type": "content_block_stop",
- "index": 0,
- })
- }
-
// 发送 message_delta
stopReason := "end_turn"
if len(toolUses) > 0 {
@@ -787,7 +923,7 @@ func (h *Handler) recordFailure() {
}
// handleClaudeNonStream Claude 非流式响应
-func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
+func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
var content string
var thinkingContent string
var toolUses []KiroToolUse
@@ -825,25 +961,36 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A
return
}
+ // 合并 thinking 内容(如果有 reasoningContentEvent 的内容)
+ thinkingFormat := config.GetThinkingConfig().ClaudeFormat
+ finalContent, extractedReasoning := extractThinkingFromContent(content)
+ if thinking && thinkingContent == "" && extractedReasoning != "" {
+ thinkingContent = extractedReasoning
+ }
+ if !thinking {
+ thinkingContent = ""
+ }
+
+ inputTokens = estimatedInputTokens
+ outputTokens = estimateClaudeOutputTokens(finalContent, thinkingContent, toolUses)
+
h.recordSuccess(inputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
- // 合并 thinking 内容(如果有 reasoningContentEvent 的内容)
- thinkingFormat := config.GetThinkingConfig().ClaudeFormat
- finalContent := content
- if thinkingContent != "" {
+ if thinking && thinkingContent != "" {
switch thinkingFormat {
case "think":
- finalContent = "" + thinkingContent + "" + content
+ finalContent = "" + thinkingContent + "" + finalContent
+ thinkingContent = ""
case "reasoning_content":
- finalContent = thinkingContent + content // Claude 格式不支持 reasoning_content,直接拼接
- default: // "thinking"
- finalContent = "" + thinkingContent + "" + content
+ finalContent = thinkingContent + finalContent // Claude 格式不支持 reasoning_content,直接拼接
+ thinkingContent = ""
+ default:
}
}
- resp := KiroToClaudeResponse(finalContent, toolUses, inputTokens, outputTokens, model)
+ resp := KiroToClaudeResponse(finalContent, thinkingContent, toolUses, inputTokens, outputTokens, model)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(resp)
}
@@ -894,18 +1041,19 @@ func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) {
thinkingCfg := config.GetThinkingConfig()
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
req.Model = actualModel
+ estimatedInputTokens := estimateOpenAIRequestInputTokens(&req)
kiroPayload := OpenAIToKiro(&req, thinking)
if req.Stream {
- h.handleOpenAIStream(w, account, kiroPayload, req.Model)
+ h.handleOpenAIStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
} else {
- h.handleOpenAINonStream(w, account, kiroPayload, req.Model)
+ h.handleOpenAINonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens)
}
}
// handleOpenAIStream OpenAI 流式响应
-func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
+func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
@@ -924,10 +1072,14 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
var toolCallIndex int
var inputTokens, outputTokens int
var credits float64
+ var rawContentBuilder strings.Builder
+ var rawReasoningBuilder strings.Builder
// Thinking 标签解析状态
var textBuffer string
var inThinkingBlock bool
+ var dropTagThinking bool
+ var thinkingSource thinkingStreamSource
// 发送 chunk 的辅助函数
// thinkingState: 0=普通内容, 1=thinking开始, 2=thinking中间, 3=thinking结束
@@ -939,6 +1091,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
var chunk map[string]interface{}
if thinkingState > 0 {
+ if !thinking {
+ return
+ }
// thinking 内容
switch thinkingFormat {
case "thinking":
@@ -1031,19 +1186,34 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
// 处理文本,解析 标签
// thinkingStarted 用于跟踪是否已发送开始标签
var thinkingStarted bool
+ var eventThinkingOpen bool
processText := func(text string, isThinking bool, forceFlush bool) {
+ if isThinking && !thinking {
+ return
+ }
+
// 如果是 reasoningContentEvent,直接输出
if isThinking {
+ if !allowReasoningSource(&thinkingSource) {
+ return
+ }
if !thinkingStarted {
sendChunk(text, 1) // 开始
thinkingStarted = true
+ eventThinkingOpen = true
} else {
sendChunk(text, 2) // 中间
}
return
}
+ if eventThinkingOpen {
+ sendChunk("", 3)
+ eventThinkingOpen = false
+ thinkingStarted = false
+ }
+
textBuffer += text
for {
@@ -1057,6 +1227,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
}
textBuffer = textBuffer[thinkingStart+10:] // 移除
inThinkingBlock = true
+ dropTagThinking = !allowTagSource(&thinkingSource)
thinkingStarted = false // 重置,准备发送新的开始标签
} else if forceFlush || len([]rune(textBuffer)) > 50 {
// 没有找到标签,安全输出(保留可能的部分标签)
@@ -1079,28 +1250,36 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
if thinkingEnd != -1 {
// 输出 thinking 内容
content := textBuffer[:thinkingEnd]
- if !thinkingStarted {
- // 一次性输出完整内容(开始+内容+结束)
- sendChunk(content, 1) // 开始
- sendChunk("", 3) // 结束(空内容,只发结束标签)
- } else {
- // 已经开始了,发送剩余内容和结束
- sendChunk(content, 3) // 结束
+ if !dropTagThinking {
+ if !thinkingStarted {
+ // 一次性输出完整内容(开始+内容+结束)
+ sendChunk(content, 1) // 开始
+ sendChunk("", 3) // 结束(空内容,只发结束标签)
+ } else {
+ // 已经开始了,发送剩余内容和结束
+ sendChunk(content, 3) // 结束
+ }
}
textBuffer = textBuffer[thinkingEnd+11:] // 移除
inThinkingBlock = false
+ dropTagThinking = false
thinkingStarted = false
} else if forceFlush {
// 强制刷新:输出剩余内容
if textBuffer != "" {
- if !thinkingStarted {
- sendChunk(textBuffer, 1) // 开始
- sendChunk("", 3) // 结束
- } else {
- sendChunk(textBuffer, 3) // 结束
+ if !dropTagThinking {
+ if !thinkingStarted {
+ sendChunk(textBuffer, 1) // 开始
+ sendChunk("", 3) // 结束
+ } else {
+ sendChunk(textBuffer, 3) // 结束
+ }
}
textBuffer = ""
}
+ inThinkingBlock = false
+ dropTagThinking = false
+ thinkingStarted = false
break
} else {
// 流式输出 thinking 块内的内容
@@ -1108,11 +1287,13 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
if len(runes) > 20 {
safeLen := len(runes) - 15 // 保留可能的 部分
if safeLen > 0 {
- if !thinkingStarted {
- sendChunk(string(runes[:safeLen]), 1) // 开始
- thinkingStarted = true
- } else {
- sendChunk(string(runes[:safeLen]), 2) // 中间
+ if !dropTagThinking {
+ if !thinkingStarted {
+ sendChunk(string(runes[:safeLen]), 1) // 开始
+ thinkingStarted = true
+ } else {
+ sendChunk(string(runes[:safeLen]), 2) // 中间
+ }
}
textBuffer = string(runes[safeLen:])
}
@@ -1128,6 +1309,11 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
if text == "" {
return
}
+ if isThinking {
+ rawReasoningBuilder.WriteString(text)
+ } else {
+ rawContentBuilder.WriteString(text)
+ }
processText(text, isThinking, false)
},
OnToolUse: func(tu KiroToolUse) {
@@ -1135,6 +1321,8 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
processText("", false, true)
args, _ := json.Marshal(tu.Input)
+ rawContentBuilder.WriteString(tu.Name)
+ rawContentBuilder.Write(args)
tc := ToolCall{ID: tu.ToolUseID, Type: "function"}
tc.Function.Name = tu.Name
tc.Function.Arguments = string(args)
@@ -1187,6 +1375,25 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
// 刷新剩余缓冲区
processText("", false, true)
+ if eventThinkingOpen {
+ sendChunk("", 3)
+ eventThinkingOpen = false
+ }
+
+ inputTokens = estimatedInputTokens
+ outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
+ reasoningOutput := rawReasoningBuilder.String()
+ if thinking && reasoningOutput == "" && extractedReasoning != "" {
+ reasoningOutput = extractedReasoning
+ }
+ if !thinking {
+ reasoningOutput = ""
+ }
+ outputTokens = estimateApproxTokens(outputContent) + estimateApproxTokens(reasoningOutput)
+ for _, tc := range toolCalls {
+ outputTokens += estimateApproxTokens(tc.Function.Name)
+ outputTokens += estimateApproxTokens(tc.Function.Arguments)
+ }
h.recordSuccess(inputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
@@ -1221,7 +1428,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
}
// handleOpenAINonStream OpenAI 非流式响应
-func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) {
+func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) {
var content string
var reasoningContent string
var toolUses []KiroToolUse
@@ -1250,16 +1457,21 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A
return
}
+ // 解析 content 中的 标签
+ finalContent, extractedReasoning := extractThinkingFromContent(content)
+ if thinking && reasoningContent == "" && extractedReasoning != "" {
+ reasoningContent = extractedReasoning
+ } else if !thinking {
+ reasoningContent = ""
+ }
+
+ inputTokens = estimatedInputTokens
+ outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)
+
h.recordSuccess(inputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits)
- // 解析 content 中的 标签
- finalContent, extractedReasoning := extractThinkingFromContent(content)
- if extractedReasoning != "" {
- reasoningContent = extractedReasoning + reasoningContent
- }
-
thinkingFormat := config.GetThinkingConfig().OpenAIFormat
resp := KiroToOpenAIResponseWithReasoning(finalContent, reasoningContent, toolUses, inputTokens, outputTokens, model, thinkingFormat)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
diff --git a/proxy/handler_test.go b/proxy/handler_test.go
new file mode 100644
index 0000000..e45b8dd
--- /dev/null
+++ b/proxy/handler_test.go
@@ -0,0 +1,50 @@
+package proxy
+
+import "testing"
+
+// TestThinkingSourceReasoningFirst checks that once a reasoning event has
+// claimed the source, a later tag block is rejected.
+func TestThinkingSourceReasoningFirst(t *testing.T) {
+	state := thinkingSourceUnknown
+
+	if accepted := allowReasoningSource(&state); !accepted {
+		t.Fatalf("expected reasoning source to be accepted first")
+	}
+	if state != thinkingSourceReasoningEvent {
+		t.Fatalf("expected source to be reasoning, got %v", state)
+	}
+	if tagAccepted := allowTagSource(&state); tagAccepted {
+		t.Fatalf("expected tag source to be rejected after reasoning source selected")
+	}
+}
+
+// TestThinkingSourceTagFirst checks that once a tag block has claimed the
+// source, a later reasoning event is rejected.
+func TestThinkingSourceTagFirst(t *testing.T) {
+	state := thinkingSourceUnknown
+
+	if accepted := allowTagSource(&state); !accepted {
+		t.Fatalf("expected tag source to be accepted first")
+	}
+	if state != thinkingSourceTagBlock {
+		t.Fatalf("expected source to be tag, got %v", state)
+	}
+	if reasoningAccepted := allowReasoningSource(&state); reasoningAccepted {
+		t.Fatalf("expected reasoning source to be rejected after tag source selected")
+	}
+}
+
+// TestThinkingSourceSameSourceRemainsAllowed checks that asking for the
+// already-selected source keeps succeeding, for both the tag-block and the
+// reasoning-event channel, each starting from the unknown state.
+func TestThinkingSourceSameSourceRemainsAllowed(t *testing.T) {
+	tagState := thinkingSourceUnknown
+	for _, failure := range []string{
+		"expected initial tag source selection to succeed",
+		"expected repeated tag source selection to stay allowed",
+	} {
+		if !allowTagSource(&tagState) {
+			t.Fatal(failure)
+		}
+	}
+
+	reasoningState := thinkingSourceUnknown
+	for _, failure := range []string{
+		"expected initial reasoning source selection to succeed",
+		"expected repeated reasoning source selection to stay allowed",
+	} {
+		if !allowReasoningSource(&reasoningState) {
+			t.Fatal(failure)
+		}
+	}
+}
diff --git a/proxy/kiro.go b/proxy/kiro.go
index 84412f2..1a6f53a 100644
--- a/proxy/kiro.go
+++ b/proxy/kiro.go
@@ -9,6 +9,7 @@ import (
"io"
"kiro-api-proxy/config"
"net/http"
+ "strconv"
"strings"
"time"
@@ -159,14 +160,10 @@ func getSortedEndpoints(preferred string) []kiroEndpoint {
// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
- body, err := json.Marshal(payload)
- if err != nil {
+ if _, err := json.Marshal(payload); err != nil {
return err
}
- // 预估输入 token(约 3 字符 = 1 token)
- estimatedInputTokens := max(1, len(body)/3)
-
// User-Agent
machineId := account.MachineId
var userAgent, amzUserAgent string
@@ -230,7 +227,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
continue
}
- err = parseEventStream(resp.Body, callback, estimatedInputTokens)
+ err = parseEventStream(resp.Body, callback)
resp.Body.Close()
return err
}
@@ -244,12 +241,13 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
// ==================== Event Stream 解析 ====================
// parseEventStream 解析 AWS Event Stream 二进制格式
-func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInputTokens int) error {
+func parseEventStream(body io.Reader, callback *KiroStreamCallback) error {
// 不使用 bufio,直接读取避免缓冲延迟
var inputTokens, outputTokens int
- var totalOutputChars int
var totalCredits float64
var currentToolUse *toolUseState
+ var lastAssistantContent string
+ var lastReasoningContent string
for {
// Prelude: 12 bytes (total_len + headers_len + crc)
@@ -292,30 +290,26 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
continue
}
+ inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens)
+
// 处理事件
switch eventType {
case "assistantResponseEvent":
if content, ok := event["content"].(string); ok && content != "" {
- callback.OnText(content, false)
- totalOutputChars += len(content)
+ normalized := normalizeChunk(content, &lastAssistantContent)
+ if normalized != "" {
+ callback.OnText(normalized, false)
+ }
}
case "reasoningContentEvent":
if text, ok := event["text"].(string); ok && text != "" {
- callback.OnText(text, true)
- totalOutputChars += len(text)
+ normalized := normalizeChunk(text, &lastReasoningContent)
+ if normalized != "" {
+ callback.OnText(normalized, true)
+ }
}
case "toolUseEvent":
currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
- case "messageMetadataEvent", "metadataEvent":
- if tokenUsage, ok := event["tokenUsage"].(map[string]interface{}); ok {
- if v, ok := tokenUsage["outputTokens"].(float64); ok {
- outputTokens = int(v)
- }
- uncached, _ := tokenUsage["uncachedInputTokens"].(float64)
- cacheRead, _ := tokenUsage["cacheReadInputTokens"].(float64)
- cacheWrite, _ := tokenUsage["cacheWriteInputTokens"].(float64)
- inputTokens = int(uncached + cacheRead + cacheWrite)
- }
case "meteringEvent":
if usage, ok := event["usage"].(float64); ok {
totalCredits += usage
@@ -323,15 +317,6 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
}
}
- // 估算 token(约 3 字符 = 1 token)
- if outputTokens == 0 && totalOutputChars > 0 {
- outputTokens = max(1, totalOutputChars/3)
- }
- // 如果 Kiro 没返回 inputTokens,使用预估值
- if inputTokens == 0 {
- inputTokens = estimatedInputTokens
- }
-
if callback.OnCredits != nil && totalCredits > 0 {
callback.OnCredits(totalCredits)
}
@@ -340,6 +325,152 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInp
return nil
}
+// updateTokensFromEvent inspects event itself plus every nested usage map
+// discovered by collectUsageMaps and returns the updated (input, output)
+// token counts. Counters that no candidate reports keep the values passed
+// in; when several candidates report the same counter the last one wins.
+//
+// For the input side, each candidate is probed in this order:
+//  1. an explicit input/prompt total,
+//  2. the sum of uncached + cache-read + cache-write counters (if > 0),
+//  3. a grand total minus the currently known output count (if > 0).
+func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, currentOutputTokens int) (int, int) {
+	sources := []map[string]interface{}{event}
+	collectUsageMaps(event, &sources)
+
+	in, out := currentInputTokens, currentOutputTokens
+
+	for _, src := range sources {
+		if src == nil {
+			continue
+		}
+
+		if n, found := readTokenNumber(src,
+			"outputTokens", "completionTokens", "totalOutputTokens",
+			"output_tokens", "completion_tokens", "total_output_tokens",
+		); found {
+			out = n
+		}
+
+		if n, found := readTokenNumber(src,
+			"inputTokens", "promptTokens", "totalInputTokens",
+			"input_tokens", "prompt_tokens", "total_input_tokens",
+		); found {
+			in = n
+			continue
+		}
+
+		fresh, _ := readTokenNumber(src, "uncachedInputTokens", "uncached_input_tokens")
+		cached, _ := readTokenNumber(src, "cacheReadInputTokens", "cache_read_input_tokens")
+		written, _ := readTokenNumber(src, "cacheWriteInputTokens", "cache_write_input_tokens", "cacheCreationInputTokens", "cache_creation_input_tokens")
+		if sum := fresh + cached + written; sum > 0 {
+			in = sum
+			continue
+		}
+
+		// out already reflects this candidate's own output counter (if it had
+		// one), so it is the right value to subtract from the grand total.
+		if total, found := readTokenNumber(src, "totalTokens", "total_tokens"); found && total > 0 {
+			if remainder := total - out; remainder > 0 {
+				in = remainder
+			}
+		}
+	}
+
+	return in, out
+}
+
+// collectUsageMaps walks v recursively through nested maps and slices and
+// appends to *out every map stored under a key that, compared
+// case-insensitively, equals "usage", "tokenusage" or "token_usage".
+// Values under those keys that are not maps are ignored, but are still
+// descended into like any other child.
+func collectUsageMaps(v interface{}, out *[]map[string]interface{}) {
+	if list, isList := v.([]interface{}); isList {
+		for _, item := range list {
+			collectUsageMaps(item, out)
+		}
+		return
+	}
+
+	object, isObject := v.(map[string]interface{})
+	if !isObject {
+		return
+	}
+
+	for key, val := range object {
+		switch strings.ToLower(key) {
+		case "usage", "tokenusage", "token_usage":
+			if usage, isMap := val.(map[string]interface{}); isMap {
+				*out = append(*out, usage)
+			}
+		}
+		collectUsageMaps(val, out)
+	}
+}
+
+// normalizeChunk compares chunk with the last snapshot stored in *previous
+// and returns only the part of chunk that is not already covered by it.
+//
+//   - An empty chunk yields "" and leaves *previous untouched.
+//   - With no snapshot yet, chunk is stored and returned as-is.
+//   - A chunk equal to, or a prefix of, the snapshot yields "" and the
+//     snapshot is kept.
+//   - A chunk that extends the snapshot replaces it and only the appended
+//     tail is returned.
+//   - Otherwise the longest prefix of chunk that is also a suffix of the
+//     snapshot is stripped, chunk becomes the new snapshot, and the rest
+//     (possibly the whole chunk) is returned.
+func normalizeChunk(chunk string, previous *string) string {
+	if chunk == "" {
+		return ""
+	}
+
+	last := *previous
+	switch {
+	case last == "":
+		*previous = chunk
+		return chunk
+	case chunk == last, strings.HasPrefix(last, chunk):
+		return ""
+	case strings.HasPrefix(chunk, last):
+		*previous = chunk
+		return chunk[len(last):]
+	}
+
+	// Search for the longest overlap, starting from the largest candidate.
+	overlap := len(last)
+	if len(chunk) < overlap {
+		overlap = len(chunk)
+	}
+	for ; overlap > 0; overlap-- {
+		if strings.HasSuffix(last, chunk[:overlap]) {
+			break
+		}
+	}
+
+	*previous = chunk
+	return chunk[overlap:]
+}
+
+// readTokenNumber returns the first value found in m under any of keys that
+// can be read as an integer. Accepted representations are float64, int,
+// int64, json.Number (integer form) and numeric strings (integer or float,
+// truncated toward zero). A key whose value cannot be converted is skipped
+// and the next key is tried. The boolean is false when nothing matched.
+func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) {
+	toInt := func(raw interface{}) (int, bool) {
+		switch val := raw.(type) {
+		case float64:
+			return int(val), true
+		case int:
+			return val, true
+		case int64:
+			return int(val), true
+		case json.Number:
+			if i, err := val.Int64(); err == nil {
+				return int(i), true
+			}
+		case string:
+			if i, err := strconv.Atoi(val); err == nil {
+				return i, true
+			}
+			if f, err := strconv.ParseFloat(val, 64); err == nil {
+				return int(f), true
+			}
+		}
+		return 0, false
+	}
+
+	for _, key := range keys {
+		if raw, present := m[key]; present {
+			if n, ok := toInt(raw); ok {
+				return n, true
+			}
+		}
+	}
+	return 0, false
+}
+
// ==================== Tool Use 处理 ====================
type toolUseState struct {
diff --git a/proxy/kiro_test.go b/proxy/kiro_test.go
new file mode 100644
index 0000000..f32190b
--- /dev/null
+++ b/proxy/kiro_test.go
@@ -0,0 +1,37 @@
+package proxy
+
+import "testing"
+
+// TestNormalizeChunkBasicProgression checks that the first chunk passes
+// through unchanged and that a chunk extending the snapshot yields only the
+// newly appended tail.
+func TestNormalizeChunkBasicProgression(t *testing.T) {
+	var prev string
+
+	first := normalizeChunk("abc", &prev)
+	if first != "abc" {
+		t.Fatalf("expected first chunk to pass through, got %q", first)
+	}
+
+	second := normalizeChunk("abcde", &prev)
+	if second != "de" {
+		t.Fatalf("expected appended delta, got %q", second)
+	}
+}
+
+// TestNormalizeChunkPrefixRewindDoesNotReplay checks that a chunk which is
+// only a prefix of the stored snapshot is dropped without shrinking the
+// snapshot, so a later extension emits just the unseen suffix.
+func TestNormalizeChunkPrefixRewindDoesNotReplay(t *testing.T) {
+	var prev string
+	normalizeChunk("abcde", &prev)
+
+	rewound := normalizeChunk("abc", &prev)
+	if rewound != "" {
+		t.Fatalf("expected rewind chunk to be ignored, got %q", rewound)
+	}
+	if prev != "abcde" {
+		t.Fatalf("expected previous snapshot to remain longest version, got %q", prev)
+	}
+
+	resumed := normalizeChunk("abcdef", &prev)
+	if resumed != "f" {
+		t.Fatalf("expected only unseen suffix after rewind, got %q", resumed)
+	}
+}
+
+// TestNormalizeChunkOverlapDelta checks that when the start of a chunk
+// repeats the end of the snapshot, only the non-overlapping tail is emitted.
+func TestNormalizeChunkOverlapDelta(t *testing.T) {
+	snapshot := "hello world"
+
+	delta := normalizeChunk("world!!!", &snapshot)
+	if delta != "!!!" {
+		t.Fatalf("expected overlap suffix delta, got %q", delta)
+	}
+}
diff --git a/proxy/token_estimator.go b/proxy/token_estimator.go
new file mode 100644
index 0000000..717543e
--- /dev/null
+++ b/proxy/token_estimator.go
@@ -0,0 +1,196 @@
+package proxy
+
+import (
+ "encoding/json"
+ "math"
+)
+
+// estimateApproxTokens returns a rough token estimate for text based on
+// character classes. Empty text is 0 tokens. Text shorter than five runes
+// is estimated at one token per three runes (minimum 1). Longer text is
+// weighted per rune: plain ASCII (letters, space, control) 1/4.5, ASCII
+// digits 1/2, ASCII punctuation 1/1.5 and non-ASCII runes 1/1.5, rounded up
+// with a minimum of 1.
+func estimateApproxTokens(text string) int {
+	var runeCount, plainASCII, digits, symbols, nonASCII int
+	for _, r := range text {
+		runeCount++
+		switch {
+		case r >= 0x80:
+			nonASCII++
+		case '0' <= r && r <= '9':
+			digits++
+		case ('!' <= r && r <= '/') || (':' <= r && r <= '@') || ('[' <= r && r <= '`') || ('{' <= r && r <= '~'):
+			symbols++
+		default:
+			plainASCII++
+		}
+	}
+
+	if runeCount == 0 {
+		return 0
+	}
+	if runeCount < 5 {
+		return max(1, int(math.Ceil(float64(runeCount)/3.0)))
+	}
+
+	weighted := float64(plainASCII)/4.5 +
+		float64(digits)/2.0 +
+		float64(symbols)/1.5 +
+		float64(nonASCII)/1.5
+
+	return max(1, int(math.Ceil(weighted)))
+}
+
+// estimateClaudeRequestInputTokens estimates the input token count of a
+// Claude-style request by summing the estimates for the system prompt,
+// every message's content, and each tool's name, description and input
+// schema. A nil request counts as 0.
+func estimateClaudeRequestInputTokens(req *ClaudeRequest) int {
+	if req == nil {
+		return 0
+	}
+
+	sum := estimateClaudeValueTokens(req.System)
+
+	for i := range req.Messages {
+		sum += estimateClaudeValueTokens(req.Messages[i].Content)
+	}
+
+	for i := range req.Tools {
+		sum += estimateApproxTokens(req.Tools[i].Name) +
+			estimateApproxTokens(req.Tools[i].Description) +
+			estimateJSONTokens(req.Tools[i].InputSchema)
+	}
+
+	return sum
+}
+
+// estimateClaudeOutputTokens estimates the output token count of a response
+// from its visible text, its thinking text, and the name plus JSON-encoded
+// input of every tool use.
+func estimateClaudeOutputTokens(content, thinkingContent string, toolUses []KiroToolUse) int {
+	sum := estimateApproxTokens(content) + estimateApproxTokens(thinkingContent)
+
+	for i := range toolUses {
+		sum += estimateApproxTokens(toolUses[i].Name) + estimateJSONTokens(toolUses[i].Input)
+	}
+
+	return sum
+}
+
+// estimateClaudeValueTokens estimates tokens for a loosely typed Claude
+// content value: nil is 0, a string is estimated directly, a slice is the
+// sum of its elements, and a map is treated as a content block. Any other
+// type is estimated from its JSON encoding.
+func estimateClaudeValueTokens(v interface{}) int {
+ switch value := v.(type) {
+ case nil:
+ return 0
+ case string:
+ return estimateApproxTokens(value)
+ case []interface{}:
+ total := 0
+ for _, part := range value {
+ total += estimateClaudeValueTokens(part)
+ }
+ return total
+ case map[string]interface{}:
+ // First try the fields that belong to the block's declared "type".
+ // A case that does not return (missing or wrongly typed field, or a
+ // zero estimate for tool_use) falls out of this inner switch into the
+ // generic probing below.
+ typeName, _ := value["type"].(string)
+ switch typeName {
+ case "text":
+ if text, ok := value["text"].(string); ok {
+ return estimateApproxTokens(text)
+ }
+ case "thinking":
+ if thinking, ok := value["thinking"].(string); ok {
+ return estimateApproxTokens(thinking)
+ }
+ case "tool_use":
+ total := 0
+ if name, ok := value["name"].(string); ok {
+ total += estimateApproxTokens(name)
+ }
+ if input, ok := value["input"]; ok {
+ total += estimateJSONTokens(input)
+ }
+ if total > 0 {
+ return total
+ }
+ case "tool_result":
+ if content, ok := value["content"]; ok {
+ return estimateClaudeValueTokens(content)
+ }
+ }
+
+ // Generic fallback: add up whichever of text / thinking / content are
+ // present, regardless of the declared type.
+ total := 0
+ if text, ok := value["text"].(string); ok {
+ total += estimateApproxTokens(text)
+ }
+ if thinking, ok := value["thinking"].(string); ok {
+ total += estimateApproxTokens(thinking)
+ }
+ if content, ok := value["content"]; ok {
+ total += estimateClaudeValueTokens(content)
+ }
+ if total > 0 {
+ return total
+ }
+
+ // Nothing recognisable contributed: estimate from the block's JSON form.
+ return estimateJSONTokens(value)
+ default:
+ return estimateJSONTokens(value)
+ }
+}
+
+// estimateJSONTokens estimates tokens for v from its JSON encoding. A nil
+// value, or one that cannot be marshalled, counts as 0.
+func estimateJSONTokens(v interface{}) int {
+	if v == nil {
+		return 0
+	}
+
+	if encoded, err := json.Marshal(v); err == nil {
+		return estimateApproxTokens(string(encoded))
+	}
+	return 0
+}
+
+// estimateOpenAIRequestInputTokens estimates the input token count of an
+// OpenAI-style request. For every message it counts the content, the tool
+// call ID and each tool call's function name and arguments; for every
+// declared tool it counts the function name, description and JSON-encoded
+// parameters. A nil request counts as 0.
+func estimateOpenAIRequestInputTokens(req *OpenAIRequest) int {
+	if req == nil {
+		return 0
+	}
+
+	var sum int
+
+	for i := range req.Messages {
+		sum += estimateOpenAIContentTokens(req.Messages[i].Content)
+		sum += estimateApproxTokens(req.Messages[i].ToolCallID)
+		for j := range req.Messages[i].ToolCalls {
+			sum += estimateApproxTokens(req.Messages[i].ToolCalls[j].Function.Name) +
+				estimateApproxTokens(req.Messages[i].ToolCalls[j].Function.Arguments)
+		}
+	}
+
+	for i := range req.Tools {
+		sum += estimateApproxTokens(req.Tools[i].Function.Name) +
+			estimateApproxTokens(req.Tools[i].Function.Description) +
+			estimateJSONTokens(req.Tools[i].Function.Parameters)
+	}
+
+	return sum
+}
+
+// estimateOpenAIContentTokens estimates tokens for an OpenAI message
+// content value. nil is 0 and a plain string is estimated directly. For
+// structured content the text extracted by extractOpenAIMessageText is
+// used, falling back to the JSON encoding when no text is found.
+func estimateOpenAIContentTokens(content interface{}) int {
+	if content == nil {
+		return 0
+	}
+	if plain, isString := content.(string); isString {
+		return estimateApproxTokens(plain)
+	}
+
+	if extracted := extractOpenAIMessageText(content); extracted != "" {
+		return estimateApproxTokens(extracted)
+	}
+	return estimateJSONTokens(content)
+}
+
+// estimateOpenAIOutputTokens estimates output tokens for an OpenAI-style
+// response. It delegates to estimateClaudeOutputTokens, passing the
+// reasoning text in the thinking-content position.
+func estimateOpenAIOutputTokens(content, reasoningContent string, toolUses []KiroToolUse) int {
+ return estimateClaudeOutputTokens(content, reasoningContent, toolUses)
+}
diff --git a/proxy/translator.go b/proxy/translator.go
index 250f23c..2971fbb 100644
--- a/proxy/translator.go
+++ b/proxy/translator.go
@@ -10,34 +10,40 @@ import (
"github.com/google/uuid"
)
-// 模型映射
-var modelMap = map[string]string{
- "claude-sonnet-4-5": "claude-sonnet-4.5",
- "claude-sonnet-4.5": "claude-sonnet-4.5",
- "claude-haiku-4-5": "claude-haiku-4.5",
- "claude-haiku-4.5": "claude-haiku-4.5",
- "claude-sonnet-4-6": "claude-sonnet-4.6",
- "claude-sonnet-4.6": "claude-sonnet-4.6",
- "claude-opus-4-6": "claude-opus-4.6",
- "claude-opus-4.6": "claude-opus-4.6",
- "claude-opus-4-5": "claude-opus-4.5",
- "claude-opus-4.5": "claude-opus-4.5",
- "claude-sonnet-4": "claude-sonnet-4",
- "claude-sonnet-4-20250514": "claude-sonnet-4",
- "claude-3-5-sonnet": "claude-sonnet-4.5",
- "claude-3-opus": "claude-sonnet-4.5",
- "claude-3-sonnet": "claude-sonnet-4",
- "claude-3-haiku": "claude-haiku-4.5",
- "gpt-4": "claude-sonnet-4.5",
- "gpt-4o": "claude-sonnet-4.5",
- "gpt-4-turbo": "claude-sonnet-4.5",
- "gpt-3.5-turbo": "claude-sonnet-4.5",
+// modelRule maps a substring of a requested model name to the model ID
+// that should be used instead.
+type modelRule struct {
+ pattern string
+ target string
+}
+
+// modelRules is evaluated top to bottom by ParseModelAndThinking, which
+// returns the target of the first rule whose pattern is contained in the
+// lower-cased model name. Order therefore matters: a pattern that is a
+// substring of another (for example "claude-sonnet-4" or "gpt-4") must come
+// after the longer, more specific patterns.
+var modelRules = []modelRule{
+ {pattern: "claude-sonnet-4-20250514", target: "claude-sonnet-4"},
+ {pattern: "claude-sonnet-4-6", target: "claude-sonnet-4.6"},
+ {pattern: "claude-sonnet-4.6", target: "claude-sonnet-4.6"},
+ {pattern: "claude-sonnet-4-5", target: "claude-sonnet-4.5"},
+ {pattern: "claude-sonnet-4.5", target: "claude-sonnet-4.5"},
+ {pattern: "claude-haiku-4-5", target: "claude-haiku-4.5"},
+ {pattern: "claude-haiku-4.5", target: "claude-haiku-4.5"},
+ {pattern: "claude-opus-4-6", target: "claude-opus-4.6"},
+ {pattern: "claude-opus-4.6", target: "claude-opus-4.6"},
+ {pattern: "claude-opus-4-5", target: "claude-opus-4.5"},
+ {pattern: "claude-opus-4.5", target: "claude-opus-4.5"},
+ {pattern: "claude-3-5-sonnet", target: "claude-sonnet-4.5"},
+ {pattern: "claude-3-opus", target: "claude-sonnet-4.5"},
+ {pattern: "claude-3-sonnet", target: "claude-sonnet-4"},
+ {pattern: "claude-3-haiku", target: "claude-haiku-4.5"},
+ {pattern: "gpt-4o", target: "claude-sonnet-4.5"},
+ {pattern: "gpt-4-turbo", target: "claude-sonnet-4.5"},
+ {pattern: "gpt-3.5-turbo", target: "claude-sonnet-4.5"},
+ {pattern: "gpt-4", target: "claude-sonnet-4.5"},
+ {pattern: "claude-sonnet-4", target: "claude-sonnet-4"},
}
// Thinking 模式提示
const ThinkingModePrompt = `enabled
200000`
+// minimalFallbackUserContent is the placeholder sent as the user message
+// content when a request carries no text, images or tool results.
+const minimalFallbackUserContent = "."
+
// ParseModelAndThinking 解析模型名称,返回实际模型和是否启用 thinking
func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) {
lower := strings.ToLower(model)
@@ -51,10 +57,9 @@ func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) {
lower = strings.ToLower(model)
}
- // 映射模型
- for k, v := range modelMap {
- if strings.Contains(lower, k) {
- return v, thinking
+ for _, rule := range modelRules {
+ if strings.Contains(lower, rule.pattern) {
+ return rule.target, thinking
}
}
@@ -93,6 +98,7 @@ type ClaudeMessage struct {
type ClaudeContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
+ Thinking string `json:"thinking,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
@@ -145,24 +151,6 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
systemPrompt = ThinkingModePrompt + "\n\n" + systemPrompt
}
- // 注入时间戳
- timestamp := time.Now().Format(time.RFC3339)
- systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt
-
- // 注入执行导向指令(防止 AI 在探索过程中丢失目标)
- executionDirective := `
-
-当用户要求执行特定任务时,你必须遵循以下纪律:
-1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向
-2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析
-3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态
-4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题
-5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务
-6. **完整交付**:直到所有任务步骤都执行完毕才算完成
-
-`
- systemPrompt = systemPrompt + "\n\n" + executionDirective
-
// 构建历史消息
history := make([]KiroHistoryMessage, 0)
var currentContent string
@@ -174,6 +162,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
if msg.Role == "user" {
content, images, toolResults := extractClaudeUserContent(msg.Content)
+ content = normalizeUserContent(content, len(images) > 0)
if isLast {
currentContent = content
@@ -226,10 +215,12 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
}
if currentContent != "" {
finalContent += currentContent
+ } else if len(currentImages) > 0 {
+ finalContent += normalizeUserContent("", true)
} else if len(currentToolResults) > 0 {
- finalContent += "Tool results provided."
+ finalContent += buildToolResultsContinuation(currentToolResults)
} else {
- finalContent += "Continue"
+ finalContent += minimalFallbackUserContent
}
// 转换工具
@@ -238,7 +229,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload {
// 构建 payload
payload := &KiroPayload{}
payload.ConversationState.ChatTriggerType = "MANUAL"
- payload.ConversationState.ConversationID = uuid.New().String()
+ payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstClaudeConversationAnchor(req.Messages))
payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{
Content: finalContent,
ModelID: modelID,
@@ -307,24 +298,13 @@ func extractClaudeUserContent(content interface{}) (string, []KiroImage, []KiroT
blockType, _ := block["type"].(string)
switch blockType {
- case "text":
+ case "text", "input_text":
if t, ok := block["text"].(string); ok {
text += t
}
- case "image":
- if source, ok := block["source"].(map[string]interface{}); ok {
- mediaType, _ := source["media_type"].(string)
- data, _ := source["data"].(string)
- format := strings.TrimPrefix(mediaType, "image/")
- if format == "jpg" {
- format = "jpeg"
- }
- images = append(images, KiroImage{
- Format: format,
- Source: struct {
- Bytes string `json:"bytes"`
- }{Bytes: data},
- })
+ case "image", "image_url", "input_image":
+ if img := extractImageFromClaudeBlock(block); img != nil {
+ images = append(images, *img)
}
case "tool_result":
toolUseID, _ := block["tool_use_id"].(string)
@@ -341,6 +321,44 @@ func extractClaudeUserContent(content interface{}) (string, []KiroImage, []KiroT
return text, images, toolResults
}
+// extractImageFromClaudeBlock tries to build a KiroImage from a content
+// block, probing several layouts in order and returning the first hit:
+//  1. block["source"]["data"] as a data URL, then as raw base64 whose format
+//     comes from media_type / mediaType / mime_type,
+//  2. block["source"]["url"] as a data URL,
+//  3. whatever extractImageFromOpenAIPart recognises in the block,
+//  4. block["data"] as a data URL.
+// It returns nil when none of these yields an image.
+func extractImageFromClaudeBlock(block map[string]interface{}) *KiroImage {
+ if source, ok := block["source"].(map[string]interface{}); ok {
+ if data, ok := source["data"].(string); ok {
+ if img := parseDataURL(data); img != nil {
+ return img
+ }
+ // Not a data URL: treat data as bare base64 and take the format
+ // from whichever media-type key is present.
+ mediaType, _ := source["media_type"].(string)
+ if mediaType == "" {
+ mediaType, _ = source["mediaType"].(string)
+ }
+ if mediaType == "" {
+ mediaType, _ = source["mime_type"].(string)
+ }
+ format := strings.TrimPrefix(strings.ToLower(mediaType), "image/")
+ if img := parseBase64Image(data, format); img != nil {
+ return img
+ }
+ }
+ if url, ok := source["url"].(string); ok {
+ if img := parseDataURL(url); img != nil {
+ return img
+ }
+ }
+ }
+
+ // Fall back to the OpenAI-style part layouts.
+ if img := extractImageFromOpenAIPart(block); img != nil {
+ return img
+ }
+
+ if data, ok := block["data"].(string); ok {
+ if img := parseDataURL(data); img != nil {
+ return img
+ }
+ }
+
+ return nil
+}
+
func extractToolResultContent(content interface{}) string {
if s, ok := content.(string); ok {
return s
@@ -396,10 +414,6 @@ func extractClaudeAssistantContent(content interface{}) (string, []KiroToolUse)
}
}
- if text == "" && len(toolUses) > 0 {
- text = "Using tools."
- }
-
return text, toolUses
}
@@ -441,9 +455,16 @@ func shortenToolName(name string) string {
// ==================== Kiro -> Claude 转换 ====================
-func KiroToClaudeResponse(content string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse {
+func KiroToClaudeResponse(content, thinkingContent string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse {
blocks := make([]ClaudeContentBlock, 0)
+ if thinkingContent != "" {
+ blocks = append(blocks, ClaudeContentBlock{
+ Type: "thinking",
+ Thinking: thinkingContent,
+ })
+ }
+
if content != "" {
blocks = append(blocks, ClaudeContentBlock{
Type: "text",
@@ -549,7 +570,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
for _, msg := range req.Messages {
if msg.Role == "system" {
- if s, ok := msg.Content.(string); ok {
+ if s := extractOpenAIMessageText(msg.Content); s != "" {
systemPrompt += s + "\n"
}
} else {
@@ -562,24 +583,6 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
systemPrompt = ThinkingModePrompt + "\n\n" + systemPrompt
}
- // 注入时间戳
- timestamp := time.Now().Format(time.RFC3339)
- systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt
-
- // 注入执行导向指令(防止 AI 在探索过程中丢失目标)
- executionDirective := `
-
-当用户要求执行特定任务时,你必须遵循以下纪律:
-1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向
-2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析
-3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态
-4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题
-5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务
-6. **完整交付**:直到所有任务步骤都执行完毕才算完成
-
-`
- systemPrompt = systemPrompt + "\n\n" + executionDirective
-
// 构建历史消息
history := make([]KiroHistoryMessage, 0)
var currentContent string
@@ -593,6 +596,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
switch msg.Role {
case "user":
content, images := extractOpenAIUserContent(msg.Content)
+ content = normalizeUserContent(content, len(images) > 0)
// 第一条 user 消息合并 system prompt
if !systemMerged && systemPrompt != "" {
@@ -615,10 +619,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
}
case "assistant":
- content, _ := msg.Content.(string)
- if content == "" && len(msg.ToolCalls) > 0 {
- content = "Using tools."
- }
+ content := extractOpenAIMessageText(msg.Content)
var toolUses []KiroToolUse
for _, tc := range msg.ToolCalls {
@@ -642,7 +643,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
})
case "tool":
- content, _ := msg.Content.(string)
+ content := extractOpenAIMessageText(msg.Content)
currentToolResults = append(currentToolResults, KiroToolResult{
ToolUseID: msg.ToolCallID,
Content: []KiroResultContent{{Text: content}},
@@ -655,7 +656,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
if !isLast {
history = append(history, KiroHistoryMessage{
UserInputMessage: &KiroUserInputMessage{
- Content: "Tool results provided.",
+ Content: buildToolResultsContinuation(currentToolResults),
ModelID: modelID,
Origin: origin,
UserInputMessageContext: &UserInputMessageContext{
@@ -672,10 +673,12 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
// 构建最终内容
finalContent := currentContent
if finalContent == "" {
- if len(currentToolResults) > 0 {
- finalContent = "Tool results provided."
+ if len(currentImages) > 0 {
+ finalContent = normalizeUserContent("", true)
+ } else if len(currentToolResults) > 0 {
+ finalContent = buildToolResultsContinuation(currentToolResults)
} else {
- finalContent = "Continue"
+ finalContent = minimalFallbackUserContent
}
}
if !systemMerged && systemPrompt != "" {
@@ -688,7 +691,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload {
// 构建 payload
payload := &KiroPayload{}
payload.ConversationState.ChatTriggerType = "MANUAL"
- payload.ConversationState.ConversationID = uuid.New().String()
+ payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstOpenAIConversationAnchor(nonSystemMessages))
payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{
Content: finalContent,
ModelID: modelID,
@@ -726,6 +729,15 @@ func extractOpenAIUserContent(content interface{}) (string, []KiroImage) {
var text string
var images []KiroImage
+ if part, ok := content.(map[string]interface{}); ok {
+ if t, ok := extractOpenAITextPart(part); ok {
+ text += t
+ }
+ if img := extractImageFromOpenAIPart(part); img != nil {
+ images = append(images, *img)
+ }
+ }
+
if parts, ok := content.([]interface{}); ok {
for _, p := range parts {
part, ok := p.(map[string]interface{})
@@ -733,50 +745,301 @@ func extractOpenAIUserContent(content interface{}) (string, []KiroImage) {
continue
}
- partType, _ := part["type"].(string)
- switch partType {
- case "text":
- if t, ok := part["text"].(string); ok {
- text += t
- }
- case "image_url":
- if imgUrl, ok := part["image_url"].(map[string]interface{}); ok {
- if url, ok := imgUrl["url"].(string); ok {
- if img := parseDataURL(url); img != nil {
- images = append(images, *img)
- }
- }
- }
+ if t, ok := extractOpenAITextPart(part); ok {
+ text += t
+ }
+ if img := extractImageFromOpenAIPart(part); img != nil {
+ images = append(images, *img)
}
}
}
+ if len(images) > 0 {
+ text = sanitizeImagePlaceholders(text)
+ }
+
return text, images
}
+// extractOpenAIMessageText flattens an OpenAI message content value into a
+// string. nil yields "" and a string is returned unchanged. Otherwise the
+// text collected by extractOpenAIUserContent is used when it is not blank.
+// Failing that, a map is searched through its nested "content" field and a
+// slice through its elements (non-blank pieces are concatenated). If no
+// text is found the value's JSON encoding is returned, or "" when it cannot
+// be marshalled.
+func extractOpenAIMessageText(content interface{}) string {
+	switch typed := content.(type) {
+	case nil:
+		return ""
+	case string:
+		return typed
+	}
+
+	if text, _ := extractOpenAIUserContent(content); strings.TrimSpace(text) != "" {
+		return text
+	}
+
+	switch typed := content.(type) {
+	case map[string]interface{}:
+		if inner, has := typed["content"]; has {
+			if innerText := extractOpenAIMessageText(inner); strings.TrimSpace(innerText) != "" {
+				return innerText
+			}
+		}
+	case []interface{}:
+		var combined strings.Builder
+		found := false
+		for _, element := range typed {
+			piece := extractOpenAIMessageText(element)
+			if strings.TrimSpace(piece) == "" {
+				continue
+			}
+			found = true
+			combined.WriteString(piece)
+		}
+		if found {
+			return combined.String()
+		}
+	}
+
+	// Last resort for every remaining shape: the raw JSON encoding.
+	encoded, err := json.Marshal(content)
+	if err != nil {
+		return ""
+	}
+	return string(encoded)
+}
+
+// buildToolResultsContinuation builds the user-message text that stands in
+// for a turn consisting only of tool results. It joins every non-blank
+// result text with a blank line between them and caps the result at 4000
+// bytes. When there is no usable text it returns
+// minimalFallbackUserContent.
+//
+// The cap is applied on a UTF-8 rune boundary: cutting at a raw byte index
+// could split a multi-byte character and leave invalid UTF-8 at the end of
+// the string.
+func buildToolResultsContinuation(toolResults []KiroToolResult) string {
+	const maxBytes = 4000
+
+	parts := make([]string, 0, len(toolResults))
+	for _, tr := range toolResults {
+		for _, c := range tr.Content {
+			if strings.TrimSpace(c.Text) != "" {
+				parts = append(parts, c.Text)
+			}
+		}
+	}
+
+	if len(parts) == 0 {
+		return minimalFallbackUserContent
+	}
+
+	joined := strings.Join(parts, "\n\n")
+	if len(joined) <= maxBytes {
+		return joined
+	}
+
+	// Step back while the byte at the cut is a UTF-8 continuation byte
+	// (10xxxxxx). A valid rune has at most three of them, so the back-off
+	// is bounded even for malformed input.
+	cut := maxBytes
+	for back := 0; back < 3 && joined[cut]&0xC0 == 0x80; back++ {
+		cut--
+	}
+	return joined[:cut]
+}
+
+// firstClaudeConversationAnchor picks a stable piece of text from the
+// conversation to seed the conversation ID. It prefers the first user
+// message that has non-blank text (trimmed) or, failing that, tool results
+// (rendered via buildToolResultsContinuation). If no user message
+// qualifies, the first message with a non-blank role and non-blank
+// extractable text is used. Returns "" when nothing is found.
+func firstClaudeConversationAnchor(messages []ClaudeMessage) string {
+	for i := range messages {
+		if messages[i].Role != "user" {
+			continue
+		}
+		text, _, results := extractClaudeUserContent(messages[i].Content)
+		if trimmed := strings.TrimSpace(text); trimmed != "" {
+			return trimmed
+		}
+		if len(results) > 0 {
+			return buildToolResultsContinuation(results)
+		}
+	}
+
+	for i := range messages {
+		if strings.TrimSpace(messages[i].Role) == "" {
+			continue
+		}
+		if trimmed := strings.TrimSpace(extractOpenAIMessageText(messages[i].Content)); trimmed != "" {
+			return trimmed
+		}
+	}
+
+	return ""
+}
+
+// firstOpenAIConversationAnchor picks a stable piece of text from the
+// conversation to seed the conversation ID: the trimmed text of the first
+// user message that has any, otherwise the trimmed text of the first
+// message of any role that has any, otherwise "".
+func firstOpenAIConversationAnchor(messages []OpenAIMessage) string {
+	firstText := func(userOnly bool) string {
+		for i := range messages {
+			if userOnly && messages[i].Role != "user" {
+				continue
+			}
+			if trimmed := strings.TrimSpace(extractOpenAIMessageText(messages[i].Content)); trimmed != "" {
+				return trimmed
+			}
+		}
+		return ""
+	}
+
+	if anchor := firstText(true); anchor != "" {
+		return anchor
+	}
+	return firstText(false)
+}
+
+// buildConversationID returns a conversation ID for the request. With a
+// non-blank anchor the ID is a SHA-1 name-based UUID derived from the model
+// ID, the trimmed system prompt and the trimmed anchor, so identical inputs
+// produce the same ID. With a blank anchor a fresh random UUID is returned.
+func buildConversationID(modelID, systemPrompt, anchor string) string {
+	trimmedAnchor := strings.TrimSpace(anchor)
+	if trimmedAnchor == "" {
+		return uuid.New().String()
+	}
+
+	seed := modelID + "\n" + strings.TrimSpace(systemPrompt) + "\n" + trimmedAnchor
+	return uuid.NewSHA1(uuid.NameSpaceURL, []byte(seed)).String()
+}
+
+// extractOpenAITextPart returns the string stored under part["text"] and
+// true, or ("", false) when that field is absent or not a string. The
+// part's "type" field does not affect the result: "text", "input_text" and
+// every other type are handled the same way.
+func extractOpenAITextPart(part map[string]interface{}) (string, bool) {
+	text, isString := part["text"].(string)
+	if !isString {
+		return "", false
+	}
+	return text, true
+}
+
+// extractImageFromOpenAIPart tries to build a KiroImage from an
+// OpenAI-style content part. It returns nil when the part declares a type
+// other than image / image_url / input_image / file / input_file, when a
+// declared MIME type is not image/*, or when no field holds a parsable
+// image. Nested "file" and "source" objects are searched recursively
+// before the part's own fields.
+func extractImageFromOpenAIPart(part map[string]interface{}) *KiroImage {
+ // A missing or empty type is accepted; a non-empty type must be one of
+ // the image/file kinds listed here.
+ partType, _ := part["type"].(string)
+ if partType != "" {
+ switch partType {
+ case "image", "image_url", "input_image", "file", "input_file":
+ default:
+ return nil
+ }
+ }
+
+ if fileObj, ok := part["file"].(map[string]interface{}); ok {
+ if img := extractImageFromOpenAIPart(fileObj); img != nil {
+ return img
+ }
+ }
+
+ if sourceObj, ok := part["source"].(map[string]interface{}); ok {
+ if img := extractImageFromOpenAIPart(sourceObj); img != nil {
+ return img
+ }
+ }
+
+ // Any explicitly declared MIME type must be an image type.
+ if raw, ok := part["mime"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
+ return nil
+ }
+ if raw, ok := part["media_type"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
+ return nil
+ }
+ if raw, ok := part["mime_type"].(string); ok && !strings.HasPrefix(strings.ToLower(raw), "image/") {
+ return nil
+ }
+
+ // Probe the payload fields in order; the first one that parses wins.
+ if raw, ok := part["url"].(string); ok {
+ if img := parseDataURL(raw); img != nil {
+ return img
+ }
+ }
+
+ // Bare base64 fields carry no format of their own; "png" is assumed.
+ if raw, ok := part["b64_json"].(string); ok {
+ if img := parseBase64Image(raw, "png"); img != nil {
+ return img
+ }
+ }
+
+ // image_url may be either a string or an object with a "url" field.
+ if raw, ok := part["image_url"]; ok {
+ switch v := raw.(type) {
+ case string:
+ if img := parseDataURL(v); img != nil {
+ return img
+ }
+ case map[string]interface{}:
+ if u, ok := v["url"].(string); ok {
+ if img := parseDataURL(u); img != nil {
+ return img
+ }
+ }
+ }
+ }
+
+ if raw, ok := part["image_base64"].(string); ok {
+ if img := parseBase64Image(raw, "png"); img != nil {
+ return img
+ }
+ }
+ // "data" is tried as a data URL first, then as bare base64.
+ if raw, ok := part["data"].(string); ok {
+ if img := parseDataURL(raw); img != nil {
+ return img
+ }
+ if img := parseBase64Image(raw, "png"); img != nil {
+ return img
+ }
+ }
+
+ return nil
+}
+
+// imagePlaceholderRe matches textual image placeholders such as "[Image 1]".
+// It is compiled once at package initialisation instead of on every call.
+var imagePlaceholderRe = regexp.MustCompile(`\[Image\s+\d+\]`)
+
+// sanitizeImagePlaceholders removes "[Image N]" placeholders from text and
+// collapses every run of whitespace into a single space. The result has no
+// leading or trailing whitespace, because strings.Fields discards it before
+// the pieces are rejoined.
+func sanitizeImagePlaceholders(text string) string {
+	cleaned := imagePlaceholderRe.ReplaceAllString(text, "")
+	return strings.Join(strings.Fields(cleaned), " ")
+}
+
+// normalizeUserContent trims surrounding whitespace from text. When the
+// trimmed text is empty but the message carries images, a default prompt
+// asking for image analysis is returned instead.
+func normalizeUserContent(text string, hasImages bool) string {
+	content := strings.TrimSpace(text)
+	if content != "" || !hasImages {
+		return content
+	}
+	return "Please analyze the attached image."
+}
+
+// parseDataURL parses a base64 "data:image/<format>[;param...];base64,<data>"
+// URL into a KiroImage. Newlines and carriage returns are removed and the
+// input is trimmed first; input containing "[Image" is rejected. It returns
+// nil when the input does not match or the payload fails validation in
+// parseBase64Image.
+//
+// NOTE(review): the pattern has three capture groups, so a successful match
+// always has length 4 and a failed match is nil. The `len(matches) != 3`
+// check therefore always returns nil for non-matches, and the final return
+// using matches[2] is unreachable.
func parseDataURL(url string) *KiroImage {
- // data:image/png;base64,xxxxx
- re := regexp.MustCompile(`^data:image/(\w+);base64,(.+)$`)
- matches := re.FindStringSubmatch(url)
+ cleaned := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(url, "\n", ""), "\r", ""))
+ if strings.Contains(cleaned, "[Image") {
+ return nil
+ }
+ re := regexp.MustCompile(`^data:image/([a-zA-Z0-9+.-]+)(;[a-zA-Z0-9=._:+-]+)*;base64,(.+)$`)
+ matches := re.FindStringSubmatch(cleaned)
+ if len(matches) == 4 {
+ return parseBase64Image(matches[3], matches[1])
+ }
if len(matches) != 3 {
return nil
}
- format := matches[1]
+ return parseBase64Image(matches[2], matches[1])
+}
+
+// parseBase64Image wraps base64 image data in a KiroImage. The format is
+// lower-cased, "jpg" is normalised to "jpeg", and an empty format defaults
+// to "png". The data must decode under at least one of the standard, raw
+// standard, URL-safe or raw URL-safe base64 alphabets; otherwise nil is
+// returned. The data string is stored unchanged, whichever alphabet
+// accepted it.
+func parseBase64Image(data, format string) *KiroImage {
+ format = strings.ToLower(format)
if format == "jpg" {
format = "jpeg"
}
// Validate the base64 payload
- if _, err := base64.StdEncoding.DecodeString(matches[2]); err != nil {
- return nil
+ if _, err := base64.StdEncoding.DecodeString(data); err != nil {
+ if _, errRaw := base64.RawStdEncoding.DecodeString(data); errRaw != nil {
+ if _, errURL := base64.URLEncoding.DecodeString(data); errURL != nil {
+ if _, errRawURL := base64.RawURLEncoding.DecodeString(data); errRawURL != nil {
+ return nil
+ }
+ }
+ }
+ }
+
+ if format == "" {
+ format = "png"
}
return &KiroImage{
Format: format,
Source: struct {
Bytes string `json:"bytes"`
- }{Bytes: matches[2]},
+ }{Bytes: data},
}
}
diff --git a/proxy/translator_test.go b/proxy/translator_test.go
new file mode 100644
index 0000000..c650081
--- /dev/null
+++ b/proxy/translator_test.go
@@ -0,0 +1,198 @@
+package proxy
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestExtractOpenAIMessageTextStructured(t *testing.T) { // Verifies extractOpenAIMessageText on a list of typed parts and on a map wrapping such a list.
+ content := []interface{}{ // Two parts with different type values: text and input_text.
+ map[string]interface{}{"type": "text", "text": "alpha"},
+ map[string]interface{}{"type": "input_text", "text": "beta"},
+ }
+
+ if got := extractOpenAIMessageText(content); got != "alphabeta" { // The expected result joins the parts with no separator.
+ t.Fatalf("expected concatenated structured text, got %q", got)
+ }
+
+ nested := map[string]interface{}{ // A map whose content key holds the part list.
+ "content": []interface{}{map[string]interface{}{"type": "text", "text": "nested"}},
+ }
+ if got := extractOpenAIMessageText(nested); got != "nested" {
+ t.Fatalf("expected nested content extraction, got %q", got)
+ }
+}
+
+func TestOpenAIToKiroPreservesStructuredAssistantAndToolContent(t *testing.T) { // Verifies that OpenAIToKiro keeps the text of list-form system, assistant and tool message content.
+ req := &OpenAIRequest{
+ Model: "claude-sonnet-4.5",
+ Messages: []OpenAIMessage{
+ {
+ Role: "system",
+ Content: []interface{}{ // System prompt split across two text parts.
+ map[string]interface{}{"type": "text", "text": "system-a"},
+ map[string]interface{}{"type": "text", "text": "system-b"},
+ },
+ },
+ {Role: "user", Content: "first-question"},
+ {
+ Role: "assistant",
+ Content: []interface{}{
+ map[string]interface{}{"type": "text", "text": "assistant-structured"},
+ },
+ },
+ {
+ Role: "tool",
+ ToolCallID: "call_1",
+ Content: []interface{}{
+ map[string]interface{}{"type": "text", "text": "tool-result-structured"},
+ },
+ },
+ },
+ }
+
+ payload := OpenAIToKiro(req, false)
+
+ if len(payload.ConversationState.History) != 2 { // Expected history: the user turn and the assistant turn; the tool message is checked below as the current message.
+ t.Fatalf("expected 2 history items, got %d", len(payload.ConversationState.History))
+ }
+
+ firstHistoryUser := payload.ConversationState.History[0].UserInputMessage
+ if firstHistoryUser == nil {
+ t.Fatalf("expected first history item to be user message")
+ }
+ if !strings.Contains(firstHistoryUser.Content, "system-a") || // Both system parts and the user question must appear in the first history entry.
+ !strings.Contains(firstHistoryUser.Content, "system-b") ||
+ !strings.Contains(firstHistoryUser.Content, "first-question") {
+ t.Fatalf("expected merged system+user content, got %q", firstHistoryUser.Content)
+ }
+
+ historyAssistant := payload.ConversationState.History[1].AssistantResponseMessage
+ if historyAssistant == nil {
+ t.Fatalf("expected second history item to be assistant message")
+ }
+ if historyAssistant.Content != "assistant-structured" {
+ t.Fatalf("expected assistant structured content to be preserved, got %q", historyAssistant.Content)
+ }
+
+ cur := payload.ConversationState.CurrentMessage.UserInputMessage // The tool message is expected to surface as the current user input.
+ if cur.Content != "tool-result-structured" {
+ t.Fatalf("expected tool-result continuation content, got %q", cur.Content)
+ }
+ if cur.UserInputMessageContext == nil || len(cur.UserInputMessageContext.ToolResults) != 1 {
+ t.Fatalf("expected one tool result in current context")
+ }
+ gotToolText := cur.UserInputMessageContext.ToolResults[0].Content[0].Text // NOTE(review): Content[0] is indexed without a length check; an empty Content would panic instead of failing with a message.
+ if gotToolText != "tool-result-structured" {
+ t.Fatalf("expected structured tool result text, got %q", gotToolText)
+ }
+}
+
+func TestOpenAIToKiroAssistantMapContentInHistory(t *testing.T) { // Verifies that assistant content given as a single map keeps its text in the converted history.
+ req := &OpenAIRequest{
+ Model: "claude-sonnet-4.5",
+ Messages: []OpenAIMessage{
+ {Role: "user", Content: "u1"},
+ {Role: "assistant", Content: map[string]interface{}{"type": "text", "text": "assistant-map"}}, // Content is one map rather than a string or a list.
+ {Role: "user", Content: "u2"},
+ },
+ }
+
+ payload := OpenAIToKiro(req, false)
+
+ if len(payload.ConversationState.History) != 2 { // The first two messages are expected in history.
+ t.Fatalf("expected 2 history entries, got %d", len(payload.ConversationState.History))
+ }
+ assistant := payload.ConversationState.History[1].AssistantResponseMessage
+ if assistant == nil {
+ t.Fatalf("expected second history entry to be assistant")
+ }
+ if assistant.Content != "assistant-map" {
+ t.Fatalf("expected assistant map content preserved, got %q", assistant.Content)
+ }
+}
+
+func TestOpenAIToKiroAssistantToolCallsDoNotInjectPlaceholder(t *testing.T) { // Verifies that an assistant turn with only tool calls is converted with empty content.
+ req := &OpenAIRequest{
+ Model: "claude-sonnet-4.5",
+ Messages: []OpenAIMessage{
+ {Role: "user", Content: "find weather"},
+ {
+ Role: "assistant",
+ Content: nil, // Tool-call-only turn: no text content.
+ ToolCalls: []ToolCall{{
+ ID: "call_1",
+ Type: "function",
+ Function: struct {
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ }{Name: "get_weather", Arguments: "{}"},
+ }},
+ },
+ {Role: "user", Content: "continue"},
+ },
+ }
+
+ payload := OpenAIToKiro(req, false)
+ if len(payload.ConversationState.History) < 2 { // Guard before indexing History[1] below.
+ t.Fatalf("expected history with assistant tool call")
+ }
+ assistant := payload.ConversationState.History[1].AssistantResponseMessage
+ if assistant == nil {
+ t.Fatalf("expected assistant history entry")
+ }
+ if assistant.Content != "" { // No placeholder text may stand in for the missing content.
+ t.Fatalf("expected empty assistant content for tool-call-only turn, got %q", assistant.Content)
+ }
+}
+
+func TestOpenAIConversationIDStableFromAnchor(t *testing.T) { // Verifies that OpenAIToKiro yields the same non-empty conversation ID for two requests sharing their leading messages; how the ID is derived is not visible in this chunk.
+ baseMessages := []OpenAIMessage{
+ {Role: "system", Content: "You are helpful"},
+ {Role: "user", Content: "Build calculator"},
+ {Role: "assistant", Content: "Sure"},
+ {Role: "user", Content: "Continue"},
+ }
+
+ reqA := &OpenAIRequest{Model: "claude-sonnet-4.5", Messages: baseMessages}
+ reqB := &OpenAIRequest{Model: "claude-sonnet-4.5", Messages: append(baseMessages, OpenAIMessage{Role: "assistant", Content: "Next step"})} // reqA plus one more message; append copies because a slice literal has cap == len, so baseMessages is not modified.
+
+ payloadA := OpenAIToKiro(reqA, false)
+ payloadB := OpenAIToKiro(reqB, false)
+
+ if payloadA.ConversationState.ConversationID == "" || payloadB.ConversationState.ConversationID == "" { // Two empty IDs would compare equal, so rule that out first.
+ t.Fatalf("expected non-empty conversation IDs")
+ }
+ if payloadA.ConversationState.ConversationID != payloadB.ConversationState.ConversationID {
+ t.Fatalf("expected stable conversation ID across turns, got %q vs %q", payloadA.ConversationState.ConversationID, payloadB.ConversationState.ConversationID)
+ }
+}
+
+func TestClaudeConversationIDStableFromAnchor(t *testing.T) { // Verifies that ClaudeToKiro yields the same non-empty conversation ID when a conversation grows by later turns.
+ reqA := &ClaudeRequest{ // Opening request: system prompt plus one user message.
+ Model: "claude-sonnet-4.5",
+ System: "sys",
+ Messages: []ClaudeMessage{
+ {Role: "user", Content: "hello"},
+ },
+ }
+ reqB := &ClaudeRequest{ // Same system prompt and first user message as reqA, extended by two more turns.
+ Model: "claude-sonnet-4.5",
+ System: "sys",
+ Messages: []ClaudeMessage{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "ok"},
+ {Role: "user", Content: "next"},
+ },
+ }
+
+ payloadA := ClaudeToKiro(reqA, false)
+ payloadB := ClaudeToKiro(reqB, false)
+
+ if payloadA.ConversationState.ConversationID == "" || payloadB.ConversationState.ConversationID == "" { // Two empty IDs would compare equal, so rule that out first.
+ t.Fatalf("expected non-empty conversation IDs")
+ }
+ if payloadA.ConversationState.ConversationID != payloadB.ConversationState.ConversationID {
+ t.Fatalf("expected stable conversation ID across turns, got %q vs %q", payloadA.ConversationState.ConversationID, payloadB.ConversationState.ConversationID)
+ }
+}