Merge pull request #226 from xilu0/main
feat(gateway): 优化 Antigravity/Gemini 思考块处理 此提交解决了思考块 (thinking blocks) 在转发过程中的兼容性问题
This commit is contained in:
@@ -275,12 +275,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
@@ -409,12 +408,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -523,6 +523,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
|
||||||
|
sanitizeThinkingBlocks(&claudeReq)
|
||||||
|
|
||||||
// 获取转换选项
|
// 获取转换选项
|
||||||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||||
@@ -534,6 +537,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("transform request: %w", err)
|
return nil, fmt.Errorf("transform request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Safety net: ensure no cache_control leaked into Gemini request
|
||||||
|
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
|
||||||
|
|
||||||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||||
action := "streamGenerateContent"
|
action := "streamGenerateContent"
|
||||||
@@ -903,6 +909,143 @@ func extractAntigravityErrorMessage(body []byte) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
|
||||||
|
// This should not be needed if transformation is correct, but serves as a safety net
|
||||||
|
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
|
||||||
|
// Try a more robust approach: parse and clean
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := removeCacheControlFromAny(data)
|
||||||
|
if !cleaned {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if result, err := json.Marshal(data); err == nil {
|
||||||
|
log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromAny walks a decoded JSON value and deletes every
// "cache_control" key it finds, at any depth, mutating maps in place.
//
// Only map[string]any and []any are descended into; every other value
// (strings, numbers, booleans, nil) is left alone. The value stored under a
// removed cache_control key is discarded without being inspected.
//
// It returns true when at least one key was deleted.
func removeCacheControlFromAny(v any) bool {
	switch node := v.(type) {
	case map[string]any:
		removed := false
		for key, value := range node {
			if key == "cache_control" {
				// Deleting the current key while ranging over a map is safe in Go.
				delete(node, key)
				removed = true
				continue
			}
			if removeCacheControlFromAny(value) {
				removed = true
			}
		}
		return removed
	case []any:
		removed := false
		for i := range node {
			if removeCacheControlFromAny(node[i]) {
				removed = true
			}
		}
		return removed
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
|
||||||
|
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
|
||||||
|
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
|
||||||
|
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
|
||||||
|
|
||||||
|
// Clean system blocks
|
||||||
|
if len(req.System) > 0 {
|
||||||
|
var systemBlocks []map[string]any
|
||||||
|
if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
|
||||||
|
for i := range systemBlocks {
|
||||||
|
if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
|
||||||
|
if removeCacheControlFromAny(systemBlocks[i]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Marshal back
|
||||||
|
if cleaned, err := json.Marshal(systemBlocks); err == nil {
|
||||||
|
req.System = cleaned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean message content blocks and flatten history
|
||||||
|
lastMsgIdx := len(req.Messages) - 1
|
||||||
|
for msgIdx := range req.Messages {
|
||||||
|
raw := req.Messages[msgIdx].Content
|
||||||
|
if len(raw) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as blocks array
|
||||||
|
var blocks []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := false
|
||||||
|
for blockIdx := range blocks {
|
||||||
|
blockType, _ := blocks[blockIdx]["type"].(string)
|
||||||
|
|
||||||
|
// Check for thinking blocks (typed or untyped)
|
||||||
|
if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
|
||||||
|
// 1. Clean cache_control
|
||||||
|
if removeCacheControlFromAny(blocks[blockIdx]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Flatten to text if it's a history message (not the last one)
|
||||||
|
if msgIdx < lastMsgIdx {
|
||||||
|
log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
|
||||||
|
// Extract thinking content
|
||||||
|
var textContent string
|
||||||
|
if t, ok := blocks[blockIdx]["thinking"].(string); ok {
|
||||||
|
textContent = t
|
||||||
|
} else {
|
||||||
|
// Fallback for non-string content (marshal it)
|
||||||
|
if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
|
||||||
|
textContent = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to text block
|
||||||
|
blocks[blockIdx]["type"] = "text"
|
||||||
|
blocks[blockIdx]["text"] = textContent
|
||||||
|
delete(blocks[blockIdx], "thinking")
|
||||||
|
delete(blocks[blockIdx], "signature")
|
||||||
|
delete(blocks[blockIdx], "cache_control") // Ensure it's gone
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back if modified
|
||||||
|
if cleaned {
|
||||||
|
if marshaled, err := json.Marshal(blocks); err == nil {
|
||||||
|
req.Messages[msgIdx].Content = marshaled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||||
// This preserves the thinking content while avoiding signature validation errors.
|
// This preserves the thinking content while avoiding signature validation errors.
|
||||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||||
|
|||||||
@@ -1227,6 +1227,9 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||||
|
removeCacheControlFromThinkingBlocks(data)
|
||||||
|
|
||||||
// 计算当前 cache_control 块数量
|
// 计算当前 cache_control 块数量
|
||||||
count := countCacheControlBlocks(data)
|
count := countCacheControlBlocks(data)
|
||||||
if count <= maxCacheControlBlocks {
|
if count <= maxCacheControlBlocks {
|
||||||
@@ -1254,6 +1257,7 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||||
|
// 注意:thinking 块不支持 cache_control,统计时跳过
|
||||||
func countCacheControlBlocks(data map[string]any) int {
|
func countCacheControlBlocks(data map[string]any) int {
|
||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
@@ -1261,6 +1265,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if system, ok := data["system"].([]any); ok {
|
if system, ok := data["system"].([]any); ok {
|
||||||
for _, item := range system {
|
for _, item := range system {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1275,6 +1283,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if content, ok := msgMap["content"].([]any); ok {
|
if content, ok := msgMap["content"].([]any); ok {
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1290,6 +1302,7 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
|
|
||||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||||
messages, ok := data["messages"].([]any)
|
messages, ok := data["messages"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1307,6 +1320,10 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
}
|
}
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1319,6 +1336,7 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
|
|
||||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||||
system, ok := data["system"].([]any)
|
system, ok := data["system"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1328,6 +1346,10 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||||
for i := len(system) - 1; i >= 0; i-- {
|
for i := len(system) - 1; i >= 0; i-- {
|
||||||
if m, ok := system[i].(map[string]any); ok {
|
if m, ok := system[i].(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1337,6 +1359,44 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromThinkingBlocks deletes the illegal cache_control field
// from every thinking block in a decoded request, mutating data in place.
//
// Thinking blocks do not support cache_control. This function looks at the
// entries of data["system"] and at the content entries of every element of
// data["messages"]; any entry that is an object with "type" == "thinking" and a
// top-level "cache_control" key has that key removed and a warning is logged.
//
// Anything that does not have the expected shape (system or content that is
// not an array, entries that are not objects, blocks of another type) is
// skipped silently.
func removeCacheControlFromThinkingBlocks(data map[string]any) {
	// stripIfThinking removes cache_control from item when item is a thinking
	// block that carries one, and reports whether it removed anything.
	stripIfThinking := func(item any) bool {
		block, ok := item.(map[string]any)
		if !ok {
			return false
		}
		if blockType, _ := block["type"].(string); blockType != "thinking" {
			return false
		}
		if _, has := block["cache_control"]; !has {
			return false
		}
		delete(block, "cache_control")
		return true
	}

	// Clean thinking blocks in system.
	if system, ok := data["system"].([]any); ok {
		for _, item := range system {
			if stripIfThinking(item) {
				log.Printf("[Warning] Removed illegal cache_control from thinking block in system")
			}
		}
	}

	// Clean thinking blocks in messages.
	messages, ok := data["messages"].([]any)
	if !ok {
		return
	}
	for msgIdx, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			continue
		}
		content, ok := msgMap["content"].([]any)
		if !ok {
			continue
		}
		for contentIdx, item := range content {
			if stripIfThinking(item) {
				log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
			}
		}
	}
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|||||||
Reference in New Issue
Block a user