162 lines
4.9 KiB
Go
162 lines
4.9 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/sjson"
|
||
)
|
||
|
||
// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
|
||
// 与 Parrot _strip_message_cache_control 语义一致。
|
||
//
|
||
// 旧策略为什么整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
|
||
// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
|
||
// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
|
||
// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
|
||
func stripMessageCacheControl(body []byte) []byte {
|
||
messages := gjson.GetBytes(body, "messages")
|
||
if !messages.IsArray() {
|
||
return body
|
||
}
|
||
msgIdx := -1
|
||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||
msgIdx++
|
||
content := msg.Get("content")
|
||
if !content.IsArray() {
|
||
return true
|
||
}
|
||
blockIdx := -1
|
||
content.ForEach(func(_, block gjson.Result) bool {
|
||
blockIdx++
|
||
if !block.Get("cache_control").Exists() {
|
||
return true
|
||
}
|
||
path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
|
||
if next, err := sjson.DeleteBytes(body, path); err == nil {
|
||
body = next
|
||
}
|
||
return true
|
||
})
|
||
return true
|
||
})
|
||
return body
|
||
}
|
||
|
||
// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
|
||
// 1. 最后一条 message
|
||
// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
|
||
//
|
||
// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
|
||
// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。
|
||
//
|
||
// cache_control ttl 策略:
|
||
// - 若目标 block 已有 cache_control.ttl → 不覆盖
|
||
// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
|
||
//
|
||
// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
|
||
func addMessageCacheBreakpoints(body []byte) []byte {
|
||
messages := gjson.GetBytes(body, "messages")
|
||
if !messages.IsArray() {
|
||
return body
|
||
}
|
||
arr := messages.Array()
|
||
if len(arr) == 0 {
|
||
return body
|
||
}
|
||
|
||
body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
|
||
|
||
if len(arr) >= 4 {
|
||
userCount := 0
|
||
for i := len(arr) - 1; i >= 0; i-- {
|
||
if arr[i].Get("role").String() != "user" {
|
||
continue
|
||
}
|
||
userCount++
|
||
if userCount == 2 {
|
||
body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
return body
|
||
}
|
||
|
||
// rewriteMessageCacheControlIfEnabled 按系统设置决定是否执行旧版 messages 缓存断点改写。
|
||
func (s *GatewayService) rewriteMessageCacheControlIfEnabled(ctx context.Context, body []byte) []byte {
|
||
if s == nil || !s.isRewriteMessageCacheControlEnabled(ctx) {
|
||
return body
|
||
}
|
||
body = stripMessageCacheControl(body)
|
||
return addMessageCacheBreakpoints(body)
|
||
}
|
||
|
||
func (s *GatewayService) isRewriteMessageCacheControlEnabled(ctx context.Context) bool {
|
||
if s == nil {
|
||
return false
|
||
}
|
||
if s.settingService != nil {
|
||
return s.settingService.IsRewriteMessageCacheControlEnabled(ctx)
|
||
}
|
||
return false
|
||
}
|
||
|
||
// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
|
||
// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
|
||
// (对齐 Parrot _inject_cache_on_msg 的行为)。
|
||
//
|
||
// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
|
||
func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
|
||
content := msg.Get("content")
|
||
|
||
if content.Type == gjson.String {
|
||
text := content.String()
|
||
blockRaw := fmt.Sprintf(
|
||
`[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
|
||
mustJSONString(text), claude.DefaultCacheControlTTL,
|
||
)
|
||
if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
|
||
body = next
|
||
}
|
||
return body
|
||
}
|
||
|
||
if !content.IsArray() {
|
||
return body
|
||
}
|
||
contentArr := content.Array()
|
||
if len(contentArr) == 0 {
|
||
return body
|
||
}
|
||
lastBlockIdx := len(contentArr) - 1
|
||
lastBlock := contentArr[lastBlockIdx]
|
||
|
||
if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
|
||
return body
|
||
}
|
||
|
||
pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
|
||
existingCC := lastBlock.Get("cache_control")
|
||
if existingCC.Exists() {
|
||
if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
|
||
body = next
|
||
}
|
||
return body
|
||
}
|
||
raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
|
||
if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
|
||
body = next
|
||
}
|
||
return body
|
||
}
|
||
|
||
// mustJSONString 把一个 Go string 序列化为合法 JSON string(含引号),
|
||
// 用于 sjson.SetRawBytes 场景下手工拼 JSON。
|
||
func mustJSONString(s string) string {
|
||
return fmt.Sprintf("%q", s)
|
||
}
|