diff --git a/README_CN.md b/README_CN.md index b8a818b3..41d399d5 100644 --- a/README_CN.md +++ b/README_CN.md @@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( --- +## OpenAI Responses 兼容注意事项 + +- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 +- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 + +--- + ## 部署方式 ### 方式一:脚本安装(推荐) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3011b97d..c4cfabc3 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 + // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, + // 或带 id 且与 call_id 匹配的 item_reference。 + if service.HasFunctionCallOutput(reqBody) { + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { + if service.HasFunctionCallOutputMissingCallID(reqBody) { + log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return + } + callIDs := service.FunctionCallOutputCallIDs(reqBody) + if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { + log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return + } + } + } + // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 94e74f22..fb5cf58b 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -70,6 +70,8 @@ type opencodeCacheMetadata struct { func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result := codexTransformResult{} + // 工具续链需求会影响存储策略与 input 过滤逻辑。 + needsToolContinuation := NeedsToolContinuation(reqBody) model := "" if v, ok := reqBody["model"].(string); ok { @@ -84,9 +86,17 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.NormalizedModel = normalizedModel } - if v, ok := reqBody["store"].(bool); !ok || v { - reqBody["store"] = false - result.Modified = true + // 续链场景强制启用 store;非续链仍按原策略强制关闭存储。 + if needsToolContinuation { + if v, ok := reqBody["store"].(bool); !ok || !v { + reqBody["store"] = true + result.Modified = true + } + } else { + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } } if v, ok := reqBody["stream"].(bool); !ok || !v { reqBody["stream"] = true @@ -121,8 +131,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { } } + // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 if input, ok := reqBody["input"].([]any); ok { - input = filterCodexInput(input) + input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true } @@ -242,7 +253,9 @@ func GetOpenCodeInstructions() string { return getOpenCodeCodexHeader() } -func filterCodexInput(input []any) []any { +// filterCodexInput 按需过滤 item_reference 与 id。 +// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。 +func filterCodexInput(input []any, preserveReferences bool) []any { filtered := make([]any, 0, len(input)) for _, item := range input { m, ok := item.(map[string]any) @@ -251,10 +264,19 @@ func filterCodexInput(input []any) []any { continue } if typ, ok := m["type"].(string); ok && typ == "item_reference" { - continue + if !preserveReferences { + continue + } } - delete(m, "id") - filtered = append(filtered, m) + newItem := m + if !preserveReferences { + newItem = make(map[string]any, len(m)) + for key, value := range m { + newItem[key] = value + } + delete(newItem, "id") + } + filtered = append(filtered, newItem) } return filtered } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go new file mode 100644 index 00000000..40e080bd --- /dev/null +++ b/backend/internal/service/openai_codex_transform_test.go @@ -0,0 +1,139 @@ +package service + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { + // 续链场景:保留 item_reference 与 id,并启用 store。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "item_reference", "id": "ref1", "text": "x"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.True(t, store) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first := input[0].(map[string]any) + require.Equal(t, "item_reference", first["type"]) + require.Equal(t, "ref1", first["id"]) + + second := input[1].(map[string]any) + require.Equal(t, "o1", second["id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) { + // 续链场景:显式 store=false 也会被强制为 true。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "store": false, + "input": []any{ + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.True(t, store) +} + +func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) { + // 非续链场景:强制 store=false,并移除 input 中的 id。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "store": true, + "input": []any{ + map[string]any{"type": "text", "id": "t1", "text": "hi"}, + }, + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + item := input[0].(map[string]any) + _, hasID := item["id"] + require.False(t, hasID) +} + +func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { + input := []any{ + map[string]any{"type": "item_reference", "id": "ref1"}, + map[string]any{"type": "text", "id": "t1", "text": "hi"}, + } + + filtered := filterCodexInput(input, false) + require.Len(t, filtered, 1) + item := filtered[0].(map[string]any) + require.Equal(t, "text", item["type"]) + _, hasID := item["id"] + require.False(t, hasID) +} + +func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { + // 空 input 应保持为空且不触发异常。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + } + + applyCodexOAuthTransform(reqBody) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func setupCodexCache(t *testing.T) { + t.Helper() + + // 使用临时 HOME 避免触发网络拉取 header。 + tempDir := t.TempDir() + t.Setenv("HOME", tempDir) + + cacheDir := filepath.Join(tempDir, ".opencode", "cache") + require.NoError(t, os.MkdirAll(cacheDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) + + meta := map[string]any{ + "etag": "", + "lastFetch": time.Now().UTC().Format(time.RFC3339), + "lastChecked": time.Now().UnixMilli(), + } + data, err := json.Marshal(meta) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) +} diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go new file mode 100644 index 00000000..e59082b2 --- /dev/null +++ b/backend/internal/service/openai_tool_continuation.go @@ -0,0 +1,213 @@ +package service + +import "strings" + +// NeedsToolContinuation 判定请求是否需要工具调用续链处理。 +// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 +// 或显式声明 tools/tool_choice。 +func NeedsToolContinuation(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + if hasNonEmptyString(reqBody["previous_response_id"]) { + return true + } + if hasToolsSignal(reqBody) { + return true + } + if hasToolChoiceSignal(reqBody) { + return true + } + if inputHasType(reqBody, "function_call_output") { + return true + } + if inputHasType(reqBody, "item_reference") { + return true + } + return false +} + +// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 +func HasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + return inputHasType(reqBody, "function_call_output") +} + +// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, +// 用于判断 function_call_output 是否具备可关联的上下文。 +func HasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 +// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 +func FunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil + } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + result := make([]string, 0, len(ids)) + for id := range ids { + result = append(result, id) + } + return result +} + +// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 +func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。 +// 用于仅依赖引用项完成续链场景的校验。 +func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +// inputHasType 判断 input 中是否存在指定类型的 item。 +func inputHasType(reqBody map[string]any, want string) bool { + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == want { + return true + } + } + return false +} + +// hasNonEmptyString 判断字段是否为非空字符串。 +func hasNonEmptyString(value any) bool { + stringValue, ok := value.(string) + return ok && strings.TrimSpace(stringValue) != "" +} + +// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。 +func hasToolsSignal(reqBody map[string]any) bool { + raw, exists := reqBody["tools"] + if !exists || raw == nil { + return false + } + if tools, ok := raw.([]any); ok { + return len(tools) > 0 + } + return false +} + +// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。 +func hasToolChoiceSignal(reqBody map[string]any) bool { + raw, exists := reqBody["tool_choice"] + if !exists || raw == nil { + return false + } + switch value := raw.(type) { + case string: + return strings.TrimSpace(value) != "" + case map[string]any: + return len(value) > 0 + default: + return false + } +} diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go new file mode 100644 index 00000000..fe737ad6 --- /dev/null +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -0,0 +1,98 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNeedsToolContinuationSignals(t *testing.T) { + // 覆盖所有触发续链的信号来源,确保判定逻辑完整。 + cases := []struct { + name string + body map[string]any + want bool + }{ + {name: "nil", body: nil, want: false}, + {name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true}, + {name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false}, + {name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true}, + {name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true}, + {name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true}, + {name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false}, + {name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false}, + {name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true}, + {name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true}, + {name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false}, + {name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, NeedsToolContinuation(tt.body)) + }) + } +} + +func TestHasFunctionCallOutput(t *testing.T) { + // 仅当 input 中存在 function_call_output 才视为续链输出。 + require.False(t, HasFunctionCallOutput(nil)) + require.True(t, HasFunctionCallOutput(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output"}}, + })) + require.False(t, HasFunctionCallOutput(map[string]any{ + "input": "text", + })) +} + +func TestHasToolCallContext(t *testing.T) { + // tool_call/function_call 必须包含 call_id,才能作为可关联上下文。 + require.False(t, HasToolCallContext(nil)) + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}}, + })) + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}}, + })) + require.False(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "tool_call"}}, + })) +} + +func TestFunctionCallOutputCallIDs(t *testing.T) { + // 仅提取非空 call_id,去重后返回。 + require.Empty(t, FunctionCallOutputCallIDs(nil)) + callIDs := FunctionCallOutputCallIDs(map[string]any{ + "input": []any{ + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + map[string]any{"type": "function_call_output", "call_id": ""}, + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + }, + }) + require.ElementsMatch(t, []string{"call_1"}, callIDs) +} + +func TestHasFunctionCallOutputMissingCallID(t *testing.T) { + require.False(t, HasFunctionCallOutputMissingCallID(nil)) + require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output"}}, + })) + require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}}, + })) +} + +func TestHasItemReferenceForCallIDs(t *testing.T) { + // item_reference 需要覆盖所有 call_id 才视为可关联上下文。 + require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"})) + require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"})) + req := map[string]any{ + "input": []any{ + map[string]any{"type": "item_reference", "id": "call_1"}, + map[string]any{"type": "item_reference", "id": "call_2"}, + }, + } + require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"})) + require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"})) + require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"})) +}