fix(网关): 修复工具续链校验与存储策略
完善 function_call_output 续链校验与引用匹配 续链场景强制 store=true,过滤 input 时避免副作用 补充续链判断与过滤相关单元测试 测试: go test ./...
This commit is contained in:
@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## OpenAI Responses 兼容注意事项
|
||||||
|
|
||||||
|
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||||
|
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
@@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
|
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||||
|
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||||
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
|
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||||
|
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Track if we've started streaming (for error handling)
|
// Track if we've started streaming (for error handling)
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ type opencodeCacheMetadata struct {
|
|||||||
|
|
||||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||||
result := codexTransformResult{}
|
result := codexTransformResult{}
|
||||||
|
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||||
|
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||||
|
|
||||||
model := ""
|
model := ""
|
||||||
if v, ok := reqBody["model"].(string); ok {
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
@@ -84,9 +86,17 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
result.NormalizedModel = normalizedModel
|
result.NormalizedModel = normalizedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
// 续链场景强制启用 store;非续链仍按原策略强制关闭存储。
|
||||||
reqBody["store"] = false
|
if needsToolContinuation {
|
||||||
result.Modified = true
|
if v, ok := reqBody["store"].(bool); !ok || !v {
|
||||||
|
reqBody["store"] = true
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||||
|
reqBody["store"] = false
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||||
reqBody["stream"] = true
|
reqBody["stream"] = true
|
||||||
@@ -121,8 +131,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||||
if input, ok := reqBody["input"].([]any); ok {
|
if input, ok := reqBody["input"].([]any); ok {
|
||||||
input = filterCodexInput(input)
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
reqBody["input"] = input
|
reqBody["input"] = input
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
@@ -242,7 +253,9 @@ func GetOpenCodeInstructions() string {
|
|||||||
return getOpenCodeCodexHeader()
|
return getOpenCodeCodexHeader()
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterCodexInput(input []any) []any {
|
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||||
|
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||||
|
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||||
filtered := make([]any, 0, len(input))
|
filtered := make([]any, 0, len(input))
|
||||||
for _, item := range input {
|
for _, item := range input {
|
||||||
m, ok := item.(map[string]any)
|
m, ok := item.(map[string]any)
|
||||||
@@ -251,10 +264,19 @@ func filterCodexInput(input []any) []any {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
|
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
|
||||||
continue
|
if !preserveReferences {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
delete(m, "id")
|
newItem := m
|
||||||
filtered = append(filtered, m)
|
if !preserveReferences {
|
||||||
|
newItem = make(map[string]any, len(m))
|
||||||
|
for key, value := range m {
|
||||||
|
newItem[key] = value
|
||||||
|
}
|
||||||
|
delete(newItem, "id")
|
||||||
|
}
|
||||||
|
filtered = append(filtered, newItem)
|
||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|||||||
139
backend/internal/service/openai_codex_transform_test.go
Normal file
139
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||||
|
// 续链场景:保留 item_reference 与 id,并启用 store。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 2)
|
||||||
|
|
||||||
|
first := input[0].(map[string]any)
|
||||||
|
require.Equal(t, "item_reference", first["type"])
|
||||||
|
require.Equal(t, "ref1", first["id"])
|
||||||
|
|
||||||
|
second := input[1].(map[string]any)
|
||||||
|
require.Equal(t, "o1", second["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) {
|
||||||
|
// 续链场景:显式 store=false 也会被强制为 true。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": false,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) {
|
||||||
|
// 非续链场景:强制 store=false,并移除 input 中的 id。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": true,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
item := input[0].(map[string]any)
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||||
|
input := []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1"},
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := filterCodexInput(input, false)
|
||||||
|
require.Len(t, filtered, 1)
|
||||||
|
item := filtered[0].(map[string]any)
|
||||||
|
require.Equal(t, "text", item["type"])
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||||
|
// 空 input 应保持为空且不触发异常。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"input": []any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupCodexCache(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 使用临时 HOME 避免触发网络拉取 header。
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", tempDir)
|
||||||
|
|
||||||
|
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||||
|
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||||
|
|
||||||
|
meta := map[string]any{
|
||||||
|
"etag": "",
|
||||||
|
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"lastChecked": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(meta)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||||
|
}
|
||||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||||
|
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||||
|
// 或显式声明 tools/tool_choice。
|
||||||
|
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolsSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolChoiceSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "function_call_output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "item_reference") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||||
|
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return inputHasType(reqBody, "function_call_output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||||
|
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||||
|
func HasToolCallContext(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "tool_call" && itemType != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||||
|
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||||
|
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||||
|
if reqBody == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ids := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
ids[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(ids))
|
||||||
|
for id := range ids {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||||
|
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID, _ := itemMap["call_id"].(string)
|
||||||
|
if strings.TrimSpace(callID) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||||
|
// 用于仅依赖引用项完成续链场景的校验。
|
||||||
|
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||||
|
if reqBody == nil || len(callIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
referenceIDs := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "item_reference" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idValue, _ := itemMap["id"].(string)
|
||||||
|
idValue = strings.TrimSpace(idValue)
|
||||||
|
if idValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
referenceIDs[idValue] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(referenceIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, callID := range callIDs {
|
||||||
|
if _, ok := referenceIDs[callID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||||
|
func inputHasType(reqBody map[string]any, want string) bool {
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||||
|
func hasNonEmptyString(value any) bool {
|
||||||
|
stringValue, ok := value.(string)
|
||||||
|
return ok && strings.TrimSpace(stringValue) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||||
|
func hasToolsSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tools"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if tools, ok := raw.([]any); ok {
|
||||||
|
return len(tools) > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||||
|
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tool_choice"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch value := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(value) != ""
|
||||||
|
case map[string]any:
|
||||||
|
return len(value) > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||||
|
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body map[string]any
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "nil", body: nil, want: false},
|
||||||
|
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||||
|
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||||
|
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||||
|
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||||
|
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||||
|
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||||
|
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||||
|
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||||
|
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||||
|
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||||
|
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutput(t *testing.T) {
|
||||||
|
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||||
|
require.False(t, HasFunctionCallOutput(nil))
|
||||||
|
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": "text",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasToolCallContext(t *testing.T) {
|
||||||
|
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||||
|
require.False(t, HasToolCallContext(nil))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||||
|
// 仅提取非空 call_id,去重后返回。
|
||||||
|
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||||
|
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||||
|
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||||
|
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||||
|
req := map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user