From c6a456c7c7d06fa2197c31ebf76aae669fc7b780 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Fri, 6 Feb 2026 08:42:55 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E5=85=BC=E5=AE=B9):=20=E5=B0=86=20Kimi=20c?= =?UTF-8?q?ached=5Ftokens=20=E6=98=A0=E5=B0=84=E5=88=B0=20Claude=20?= =?UTF-8?q?=E6=A0=87=E5=87=86=20cache=5Fread=5Finput=5Ftokens?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Kimi 等 Claude 兼容 API 返回缓存信息使用 OpenAI 风格的 cached_tokens 字段, 而非 Claude 标准的 cache_read_input_tokens,导致客户端收不到缓存命中信息且 内部计费缓存折扣为 0。 新增 reconcileCachedTokens 辅助函数,在 cache_read_input_tokens == 0 且 cached_tokens > 0 时自动填充,覆盖流式(message_start/message_delta)和 非流式两种响应路径。对 Claude 原生上游无影响。 Co-Authored-By: Claude Opus 4.6 --- .../service/gateway_cached_tokens_test.go | 300 ++++++++++++++++++ backend/internal/service/gateway_service.go | 43 +++ 2 files changed, 343 insertions(+) create mode 100644 backend/internal/service/gateway_cached_tokens_test.go diff --git a/backend/internal/service/gateway_cached_tokens_test.go b/backend/internal/service/gateway_cached_tokens_test.go new file mode 100644 index 00000000..a51e928c --- /dev/null +++ b/backend/internal/service/gateway_cached_tokens_test.go @@ -0,0 +1,300 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ---------- reconcileCachedTokens 单元测试 ---------- + +func TestReconcileCachedTokens_NilUsage(t *testing.T) { + assert.False(t, reconcileCachedTokens(nil)) +} + +func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) { + // 已有标准字段,不应覆盖 + usage := map[string]any{ + "cache_read_input_tokens": float64(100), + "cached_tokens": float64(50), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(100), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_KimiStyle(t *testing.T) { + // Kimi 风格:cache_read_input_tokens=0,cached_tokens>0 + usage := map[string]any{ + "input_tokens": float64(23), + "cache_creation_input_tokens": float64(0), + "cache_read_input_tokens": float64(0), + "cached_tokens": float64(23), + } + assert.True(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(23), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) { + // 无 cached_tokens 字段(原生 Claude) + usage := map[string]any{ + "input_tokens": float64(100), + "cache_read_input_tokens": float64(0), + "cache_creation_input_tokens": float64(0), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(0), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) { + // cached_tokens 为 0,不应覆盖 + usage := map[string]any{ + "cache_read_input_tokens": float64(0), + "cached_tokens": float64(0), + } + assert.False(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(0), usage["cache_read_input_tokens"]) +} + +func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) { + // cache_read_input_tokens 字段完全不存在,cached_tokens > 0 + usage := map[string]any{ + "cached_tokens": float64(42), + } + assert.True(t, reconcileCachedTokens(usage)) + assert.Equal(t, float64(42), usage["cache_read_input_tokens"]) +} + +// ---------- 流式 message_start 事件 reconcile 测试 ---------- + +func TestStreamingReconcile_MessageStart(t *testing.T) { + // 模拟 Kimi 返回的 message_start SSE 事件 + eventJSON := `{ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "model": "kimi", + "usage": { + "input_tokens": 23, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cached_tokens": 23 + } + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + eventType, _ := event["type"].(string) + require.Equal(t, "message_start", eventType) + + // 模拟 processSSEEvent 中的 reconcile 逻辑 + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + + // 验证 cache_read_input_tokens 已被填充 + msg := event["message"].(map[string]any) + usage := msg["usage"].(map[string]any) + assert.Equal(t, float64(23), usage["cache_read_input_tokens"]) + + // 验证重新序列化后 JSON 也包含正确值 + data, err := json.Marshal(event) + require.NoError(t, err) + assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int()) +} + +func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) { + // 原生 Claude 不返回 cached_tokens,reconcile 不应改变任何值 + eventJSON := `{ + "type": "message_start", + "message": { + "usage": { + "input_tokens": 100, + "cache_creation_input_tokens": 50, + "cache_read_input_tokens": 30 + } + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + + msg := event["message"].(map[string]any) + usage := msg["usage"].(map[string]any) + assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) +} + +// ---------- 流式 message_delta 事件 reconcile 测试 ---------- + +func TestStreamingReconcile_MessageDelta(t *testing.T) { + // 模拟 Kimi 返回的 message_delta SSE 事件 + eventJSON := `{ + "type": "message_delta", + "usage": { + "output_tokens": 7, + "cache_read_input_tokens": 0, + "cached_tokens": 15 + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + eventType, _ := event["type"].(string) + require.Equal(t, "message_delta", eventType) + + // 模拟 processSSEEvent 中的 reconcile 逻辑 + if u, ok := event["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + + usage := event["usage"].(map[string]any) + assert.Equal(t, float64(15), usage["cache_read_input_tokens"]) +} + +func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) { + // 原生 Claude 的 message_delta 通常没有 cached_tokens + eventJSON := `{ + "type": "message_delta", + "usage": { + "output_tokens": 50 + } + }` + + var event map[string]any + require.NoError(t, json.Unmarshal([]byte(eventJSON), &event)) + + if u, ok := event["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + + usage := event["usage"].(map[string]any) + _, hasCacheRead := usage["cache_read_input_tokens"] + assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens") +} + +// ---------- 非流式响应 reconcile 测试 ---------- + +func TestNonStreamingReconcile_KimiResponse(t *testing.T) { + // 模拟 Kimi 非流式响应 + body := []byte(`{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "kimi", + "usage": { + "input_tokens": 23, + "output_tokens": 7, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cached_tokens": 23, + "prompt_tokens": 23, + "completion_tokens": 7 + } + }`) + + // 模拟 handleNonStreamingResponse 中的逻辑 + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + // reconcile + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // 验证内部 usage(计费用) + assert.Equal(t, 23, response.Usage.CacheReadInputTokens) + assert.Equal(t, 23, response.Usage.InputTokens) + assert.Equal(t, 7, response.Usage.OutputTokens) + + // 验证返回给客户端的 JSON body + assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int()) +} + +func TestNonStreamingReconcile_NativeClaude(t *testing.T) { + // 原生 Claude 响应:cache_read_input_tokens 已有值 + body := []byte(`{ + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 20, + "cache_read_input_tokens": 30 + } + }`) + + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + originalBody := make([]byte, len(body)) + copy(originalBody, body) + + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // 不应修改 + assert.Equal(t, 30, response.Usage.CacheReadInputTokens) +} + +func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) { + // 没有 cached_tokens 字段 + body := []byte(`{ + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0 + } + }`) + + var response struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(body, &response)) + + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + + // cache_read_input_tokens 应保持为 0 + assert.Equal(t, 0, response.Usage.CacheReadInputTokens) + assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int()) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9aecce22..bbfb1723 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4176,6 +4176,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http eventName = eventType } + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + reconcileCachedTokens(u) + } + } + if needModelReplace { if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { @@ -4526,6 +4540,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 兼容 Kimi cached_tokens → cache_read_input_tokens + if response.Usage.CacheReadInputTokens == 0 { + cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() + if cachedTokens > 0 { + response.Usage.CacheReadInputTokens = int(cachedTokens) + if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) @@ -5311,3 +5336,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, return models } + +// reconcileCachedTokens 兼容 Kimi 等上游: +// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens +func reconcileCachedTokens(usage map[string]any) bool { + if usage == nil { + return false + } + cacheRead, _ := usage["cache_read_input_tokens"].(float64) + if cacheRead > 0 { + return false // 已有标准字段,无需处理 + } + cached, _ := usage["cached_tokens"].(float64) + if cached <= 0 { + return false + } + usage["cache_read_input_tokens"] = cached + return true +}