fix(网关): 修复工具续链校验与存储策略

完善 function_call_output 续链校验与引用匹配
续链场景强制 store=true,过滤 input 时避免副作用
补充续链判断与过滤相关单元测试

测试: go test ./...
This commit is contained in:
yangjianbo
2026-01-13 16:47:35 +08:00
parent 0e44829720
commit 70eaa450db
6 changed files with 507 additions and 8 deletions

View File

@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
---
## 部署方式
### 方式一:脚本安装(推荐)

View File

@@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
// Track if we've started streaming (for error handling)
streamStarted := false

View File

@@ -70,6 +70,8 @@ type opencodeCacheMetadata struct {
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation := NeedsToolContinuation(reqBody)
model := ""
if v, ok := reqBody["model"].(string); ok {
@@ -84,10 +86,18 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result.NormalizedModel = normalizedModel
}
// 续链场景强制启用 store非续链仍按原策略强制关闭存储。
if needsToolContinuation {
if v, ok := reqBody["store"].(bool); !ok || !v {
reqBody["store"] = true
result.Modified = true
}
} else {
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
result.Modified = true
@@ -121,8 +131,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
}
}
// 续链场景保留 item_reference 与 id避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
input = filterCodexInput(input)
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
}
@@ -242,7 +253,9 @@ func GetOpenCodeInstructions() string {
return getOpenCodeCodexHeader()
}
func filterCodexInput(input []any) []any {
// filterCodexInput 按需过滤 item_reference 与 id。
// preserveReferences 为 true 时保持引用与 id以满足续链请求对上下文的依赖。
func filterCodexInput(input []any, preserveReferences bool) []any {
filtered := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
@@ -251,10 +264,19 @@ func filterCodexInput(input []any) []any {
continue
}
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
if !preserveReferences {
continue
}
delete(m, "id")
filtered = append(filtered, m)
}
newItem := m
if !preserveReferences {
newItem = make(map[string]any, len(m))
for key, value := range m {
newItem[key] = value
}
delete(newItem, "id")
}
filtered = append(filtered, newItem)
}
return filtered
}

View File

@@ -0,0 +1,139 @@
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id并启用 store。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.True(t, store)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first := input[0].(map[string]any)
require.Equal(t, "item_reference", first["type"])
require.Equal(t, "ref1", first["id"])
second := input[1].(map[string]any)
require.Equal(t, "o1", second["id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) {
// 续链场景:显式 store=false 也会被强制为 true。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"store": false,
"input": []any{
map[string]any{"type": "function_call_output", "call_id": "call_1"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.True(t, store)
}
func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:强制 store=false并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"store": true,
"input": []any{
map[string]any{"type": "text", "id": "t1", "text": "hi"},
},
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.False(t, store)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
item := input[0].(map[string]any)
_, hasID := item["id"]
require.False(t, hasID)
}
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
input := []any{
map[string]any{"type": "item_reference", "id": "ref1"},
map[string]any{"type": "text", "id": "t1", "text": "hi"},
}
filtered := filterCodexInput(input, false)
require.Len(t, filtered, 1)
item := filtered[0].(map[string]any)
require.Equal(t, "text", item["type"])
_, hasID := item["id"]
require.False(t, hasID)
}
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
applyCodexOAuthTransform(reqBody)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 0)
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}

View File

@@ -0,0 +1,213 @@
package service
import "strings"
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
if hasNonEmptyString(reqBody["previous_response_id"]) {
return true
}
if hasToolsSignal(reqBody) {
return true
}
if hasToolChoiceSignal(reqBody) {
return true
}
if inputHasType(reqBody, "function_call_output") {
return true
}
if inputHasType(reqBody, "item_reference") {
return true
}
return false
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output用于触发续链校验。
func HasFunctionCallOutput(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return inputHasType(reqBody, "function_call_output")
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call
// 用于判断 function_call_output 是否具备可关联的上下文。
func HasToolCallContext(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "tool_call" && itemType != "function_call" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
return true
}
}
return false
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// 仅返回非空 call_id用于与 item_reference.id 做匹配校验。
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
if reqBody == nil {
return nil
}
input, ok := reqBody["input"].([]any)
if !ok {
return nil
}
ids := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
ids[callID] = struct{}{}
}
}
if len(ids) == 0 {
return nil
}
result := make([]string, 0, len(ids))
for id := range ids {
result = append(result, id)
}
return result
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) == "" {
return true
}
}
return false
}
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
// 用于仅依赖引用项完成续链场景的校验。
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
if reqBody == nil || len(callIDs) == 0 {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
referenceIDs := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "item_reference" {
continue
}
idValue, _ := itemMap["id"].(string)
idValue = strings.TrimSpace(idValue)
if idValue == "" {
continue
}
referenceIDs[idValue] = struct{}{}
}
if len(referenceIDs) == 0 {
return false
}
for _, callID := range callIDs {
if _, ok := referenceIDs[callID]; !ok {
return false
}
}
return true
}
// inputHasType 判断 input 中是否存在指定类型的 item。
func inputHasType(reqBody map[string]any, want string) bool {
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == want {
return true
}
}
return false
}
// hasNonEmptyString 判断字段是否为非空字符串。
func hasNonEmptyString(value any) bool {
stringValue, ok := value.(string)
return ok && strings.TrimSpace(stringValue) != ""
}
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
func hasToolsSignal(reqBody map[string]any) bool {
raw, exists := reqBody["tools"]
if !exists || raw == nil {
return false
}
if tools, ok := raw.([]any); ok {
return len(tools) > 0
}
return false
}
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil
func hasToolChoiceSignal(reqBody map[string]any) bool {
raw, exists := reqBody["tool_choice"]
if !exists || raw == nil {
return false
}
switch value := raw.(type) {
case string:
return strings.TrimSpace(value) != ""
case map[string]any:
return len(value) > 0
default:
return false
}
}

View File

@@ -0,0 +1,98 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNeedsToolContinuationSignals(t *testing.T) {
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
cases := []struct {
name string
body map[string]any
want bool
}{
{name: "nil", body: nil, want: false},
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
})
}
}
func TestHasFunctionCallOutput(t *testing.T) {
// 仅当 input 中存在 function_call_output 才视为续链输出。
require.False(t, HasFunctionCallOutput(nil))
require.True(t, HasFunctionCallOutput(map[string]any{
"input": []any{map[string]any{"type": "function_call_output"}},
}))
require.False(t, HasFunctionCallOutput(map[string]any{
"input": "text",
}))
}
func TestHasToolCallContext(t *testing.T) {
// tool_call/function_call 必须包含 call_id才能作为可关联上下文。
require.False(t, HasToolCallContext(nil))
require.True(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
}))
require.True(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
}))
require.False(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "tool_call"}},
}))
}
func TestFunctionCallOutputCallIDs(t *testing.T) {
// 仅提取非空 call_id去重后返回。
require.Empty(t, FunctionCallOutputCallIDs(nil))
callIDs := FunctionCallOutputCallIDs(map[string]any{
"input": []any{
map[string]any{"type": "function_call_output", "call_id": "call_1"},
map[string]any{"type": "function_call_output", "call_id": ""},
map[string]any{"type": "function_call_output", "call_id": "call_1"},
},
})
require.ElementsMatch(t, []string{"call_1"}, callIDs)
}
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
require.False(t, HasFunctionCallOutputMissingCallID(nil))
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output"}},
}))
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
}))
}
func TestHasItemReferenceForCallIDs(t *testing.T) {
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
req := map[string]any{
"input": []any{
map[string]any{"type": "item_reference", "id": "call_1"},
map[string]any{"type": "item_reference", "id": "call_2"},
},
}
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
}