diff --git a/README_CN.md b/README_CN.md index b8a818b3..41d399d5 100644 --- a/README_CN.md +++ b/README_CN.md @@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( --- +## OpenAI Responses 兼容注意事项 + +- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 +- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 + +--- + ## 部署方式 ### 方式一:脚本安装(推荐) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3011b97d..c4cfabc3 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 + // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, + // 或带 id 且与 call_id 匹配的 item_reference。 + if service.HasFunctionCallOutput(reqBody) { + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { + if service.HasFunctionCallOutputMissingCallID(reqBody) { + log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return + } + callIDs := service.FunctionCallOutputCallIDs(reqBody) + if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { + log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return + } + } + } + // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/middleware/rate_limiter.go b/backend/internal/middleware/rate_limiter.go index 13b71683..819d74c2 100644 --- a/backend/internal/middleware/rate_limiter.go +++ b/backend/internal/middleware/rate_limiter.go @@ -2,7 +2,10 @@ package middleware import ( "context" + "fmt" + "log" "net/http" + "strconv" "time" "github.com/gin-gonic/gin" @@ -25,15 +28,34 @@ type RateLimitOptions struct { var rateLimitScript = redis.NewScript(` local current = redis.call('INCR', KEYS[1]) local ttl = redis.call('PTTL', KEYS[1]) -if current == 1 or ttl == -1 then +local repaired = 0 +if current == 1 then redis.call('PEXPIRE', KEYS[1], ARGV[1]) +elseif ttl == -1 then + redis.call('PEXPIRE', KEYS[1], ARGV[1]) + repaired = 1 end -return current +return {current, repaired} `) // rateLimitRun 允许测试覆写脚本执行逻辑 -var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { - return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64() +var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice() + if err != nil { + return 0, false, err + } + if len(values) < 2 { + return 0, false, fmt.Errorf("rate limit script returned %d values", len(values)) + } + count, err := parseInt64(values[0]) + if err != nil { + return 0, false, err + } + repaired, err := parseInt64(values[1]) + if err != nil { + return 0, false, err + } + return count, repaired == 1, nil } // RateLimiter Redis 速率限制器 @@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati windowMillis := windowTTLMillis(window) // 使用 Lua 脚本原子操作增加计数并设置过期 - count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis) + count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis) if err != nil { + log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err) if failureMode == RateLimitFailClose { abortRateLimit(c) return @@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati c.Next() return } + if repaired { + log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis) + } // 超过限制 if count > int64(limit) { @@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) { "message": "Too many requests, please try again later", }) } + +func failureModeLabel(mode RateLimitFailureMode) string { + if mode == RateLimitFailClose { + return "fail-close" + } + return "fail-open" +} + +func parseInt64(value any) (int64, error) { + switch v := value.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case string: + parsed, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, err + } + return parsed, nil + default: + return 0, fmt.Errorf("unexpected value type %T", value) + } +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 7c72e5be..0c379c0f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) { originalRun := rateLimitRun counts := []int64{1, 2} callIndex := 0 - rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { if callIndex >= len(counts) { - return counts[len(counts)-1], nil + return counts[len(counts)-1], false, nil } value := counts[callIndex] callIndex++ - return value, nil + return value, false, nil } t.Cleanup(func() { rateLimitRun = originalRun diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 97b405b2..dcc30fc1 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -74,6 +74,8 @@ type opencodeCacheMetadata struct { func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result := codexTransformResult{} + // 工具续链需求会影响存储策略与 input 过滤逻辑。 + needsToolContinuation := NeedsToolContinuation(reqBody) model := "" if v, ok := reqBody["model"].(string); ok { @@ -88,9 +90,17 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.NormalizedModel = normalizedModel } - if v, ok := reqBody["store"].(bool); !ok || v { - reqBody["store"] = false - result.Modified = true + // 续链场景强制启用 store;非续链仍按原策略强制关闭存储。 + if needsToolContinuation { + if v, ok := reqBody["store"].(bool); !ok || !v { + reqBody["store"] = true + result.Modified = true + } + } else { + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } } if v, ok := reqBody["stream"].(bool); !ok || !v { reqBody["stream"] = true @@ -124,7 +134,7 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.Modified = true } } else if existingInstructions == "" { - // If no opencode instructions available, try codex CLI instructions + // 未获取到 opencode 指令时,回退使用 Codex CLI 指令。 codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) if codexInstructions != "" { reqBody["instructions"] = codexInstructions @@ -132,8 +142,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { } } + // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 if input, ok := reqBody["input"].([]any); ok { - input = filterCodexInput(input) + input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true } @@ -246,15 +257,15 @@ func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { } func getOpenCodeCodexHeader() string { - // Try to get from opencode repository first + // 优先从 opencode 仓库缓存获取指令。 opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") - // If opencode instructions are available, return them + // 若 opencode 指令可用,直接返回。 if opencodeInstructions != "" { return opencodeInstructions } - // Fallback to local codex CLI instructions + // 否则回退使用本地 Codex CLI 指令。 return getCodexCLIInstructions() } @@ -266,10 +277,12 @@ func GetOpenCodeInstructions() string { return getOpenCodeCodexHeader() } +// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。 func GetCodexCLIInstructions() string { return getCodexCLIInstructions() } +// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。 func ReplaceWithCodexInstructions(reqBody map[string]any) bool { codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) if codexInstructions == "" { @@ -285,6 +298,7 @@ func ReplaceWithCodexInstructions(reqBody map[string]any) bool { return false } +// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。 func IsInstructionError(errorMessage string) bool { if errorMessage == "" { return false @@ -309,7 +323,9 @@ func IsInstructionError(errorMessage string) bool { return false } -func filterCodexInput(input []any) []any { +// filterCodexInput 按需过滤 item_reference 与 id。 +// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。 +func filterCodexInput(input []any, preserveReferences bool) []any { filtered := make([]any, 0, len(input)) for _, item := range input { m, ok := item.(map[string]any) @@ -319,23 +335,49 @@ func filterCodexInput(input []any) []any { } typ, _ := m["type"].(string) if typ == "item_reference" { - filtered = append(filtered, m) + if !preserveReferences { + continue + } + newItem := make(map[string]any, len(m)) + for key, value := range m { + newItem[key] = value + } + filtered = append(filtered, newItem) continue } - // Strip per-item ids; keep call_id only for tool call items so outputs can match. + + newItem := m + copied := false + // 仅在需要修改字段时创建副本,避免直接改写原始输入。 + ensureCopy := func() { + if copied { + return + } + newItem = make(map[string]any, len(m)) + for key, value := range m { + newItem[key] = value + } + copied = true + } + if isCodexToolCallItemType(typ) { - callID, _ := m["call_id"].(string) - if strings.TrimSpace(callID) == "" { + if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" { if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" { - m["call_id"] = id + ensureCopy() + newItem["call_id"] = id } } } - delete(m, "id") - if !isCodexToolCallItemType(typ) { - delete(m, "call_id") + + if !preserveReferences { + ensureCopy() + delete(newItem, "id") + if !isCodexToolCallItemType(typ) { + delete(newItem, "call_id") + } } - filtered = append(filtered, m) + + filtered = append(filtered, newItem) } return filtered } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go new file mode 100644 index 00000000..9663229f --- /dev/null +++ b/backend/internal/service/openai_codex_transform_test.go @@ -0,0 +1,147 @@ +package service + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { + // 续链场景:保留 item_reference 与 id,并启用 store。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "item_reference", "id": "ref1", "text": "x"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.True(t, store) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + // 校验 input[0] 为 map,避免断言失败导致测试中断。 + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "item_reference", first["type"]) + require.Equal(t, "ref1", first["id"]) + + // 校验 input[1] 为 map,确保后续字段断言安全。 + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "o1", second["id"]) +} + +func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) { + // 续链场景:显式 store=false 也会被强制为 true。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "store": false, + "input": []any{ + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.True(t, store) +} + +func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) { + // 非续链场景:强制 store=false,并移除 input 中的 id。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "store": true, + "input": []any{ + map[string]any{"type": "text", "id": "t1", "text": "hi"}, + }, + } + + applyCodexOAuthTransform(reqBody) + + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + // 校验 input[0] 为 map,避免类型不匹配触发 errcheck。 + item, ok := input[0].(map[string]any) + require.True(t, ok) + _, hasID := item["id"] + require.False(t, hasID) +} + +func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { + input := []any{ + map[string]any{"type": "item_reference", "id": "ref1"}, + map[string]any{"type": "text", "id": "t1", "text": "hi"}, + } + + filtered := filterCodexInput(input, false) + require.Len(t, filtered, 1) + // 校验 filtered[0] 为 map,确保字段检查可靠。 + item, ok := filtered[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "text", item["type"]) + _, hasID := item["id"] + require.False(t, hasID) +} + +func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { + // 空 input 应保持为空且不触发异常。 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + } + + applyCodexOAuthTransform(reqBody) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) +} + +func setupCodexCache(t *testing.T) { + t.Helper() + + // 使用临时 HOME 避免触发网络拉取 header。 + tempDir := t.TempDir() + t.Setenv("HOME", tempDir) + + cacheDir := filepath.Join(tempDir, ".opencode", "cache") + require.NoError(t, os.MkdirAll(cacheDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) + + meta := map[string]any{ + "etag": "", + "lastFetch": time.Now().UTC().Format(time.RFC3339), + "lastChecked": time.Now().UnixMilli(), + } + data, err := json.Marshal(meta) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ffd42d2f..bac117b8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -546,7 +546,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) - // Apply model mapping for all requests (including Codex CLI) + // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) @@ -554,7 +554,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // Apply Codex model normalization for all OpenAI accounts + // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { normalizedModel := normalizeCodexModel(model) if normalizedModel != "" && normalizedModel != model { @@ -566,7 +566,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - // Normalize reasoning.effort parameter (minimal -> none) + // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go new file mode 100644 index 00000000..e59082b2 --- /dev/null +++ b/backend/internal/service/openai_tool_continuation.go @@ -0,0 +1,213 @@ +package service + +import "strings" + +// NeedsToolContinuation 判定请求是否需要工具调用续链处理。 +// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 +// 或显式声明 tools/tool_choice。 +func NeedsToolContinuation(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + if hasNonEmptyString(reqBody["previous_response_id"]) { + return true + } + if hasToolsSignal(reqBody) { + return true + } + if hasToolChoiceSignal(reqBody) { + return true + } + if inputHasType(reqBody, "function_call_output") { + return true + } + if inputHasType(reqBody, "item_reference") { + return true + } + return false +} + +// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 +func HasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + return inputHasType(reqBody, "function_call_output") +} + +// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, +// 用于判断 function_call_output 是否具备可关联的上下文。 +func HasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 +// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 +func FunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil + } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + result := make([]string, 0, len(ids)) + for id := range ids { + result = append(result, id) + } + return result +} + +// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 +func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。 +// 用于仅依赖引用项完成续链场景的校验。 +func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +// inputHasType 判断 input 中是否存在指定类型的 item。 +func inputHasType(reqBody map[string]any, want string) bool { + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == want { + return true + } + } + return false +} + +// hasNonEmptyString 判断字段是否为非空字符串。 +func hasNonEmptyString(value any) bool { + stringValue, ok := value.(string) + return ok && strings.TrimSpace(stringValue) != "" +} + +// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。 +func hasToolsSignal(reqBody map[string]any) bool { + raw, exists := reqBody["tools"] + if !exists || raw == nil { + return false + } + if tools, ok := raw.([]any); ok { + return len(tools) > 0 + } + return false +} + +// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。 +func hasToolChoiceSignal(reqBody map[string]any) bool { + raw, exists := reqBody["tool_choice"] + if !exists || raw == nil { + return false + } + switch value := raw.(type) { + case string: + return strings.TrimSpace(value) != "" + case map[string]any: + return len(value) > 0 + default: + return false + } +} diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go new file mode 100644 index 00000000..fe737ad6 --- /dev/null +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -0,0 +1,98 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNeedsToolContinuationSignals(t *testing.T) { + // 覆盖所有触发续链的信号来源,确保判定逻辑完整。 + cases := []struct { + name string + body map[string]any + want bool + }{ + {name: "nil", body: nil, want: false}, + {name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true}, + {name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false}, + {name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true}, + {name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true}, + {name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true}, + {name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false}, + {name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false}, + {name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true}, + {name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true}, + {name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false}, + {name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, NeedsToolContinuation(tt.body)) + }) + } +} + +func TestHasFunctionCallOutput(t *testing.T) { + // 仅当 input 中存在 function_call_output 才视为续链输出。 + require.False(t, HasFunctionCallOutput(nil)) + require.True(t, HasFunctionCallOutput(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output"}}, + })) + require.False(t, HasFunctionCallOutput(map[string]any{ + "input": "text", + })) +} + +func TestHasToolCallContext(t *testing.T) { + // tool_call/function_call 必须包含 call_id,才能作为可关联上下文。 + require.False(t, HasToolCallContext(nil)) + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}}, + })) + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}}, + })) + require.False(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": "tool_call"}}, + })) +} + +func TestFunctionCallOutputCallIDs(t *testing.T) { + // 仅提取非空 call_id,去重后返回。 + require.Empty(t, FunctionCallOutputCallIDs(nil)) + callIDs := FunctionCallOutputCallIDs(map[string]any{ + "input": []any{ + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + map[string]any{"type": "function_call_output", "call_id": ""}, + map[string]any{"type": "function_call_output", "call_id": "call_1"}, + }, + }) + require.ElementsMatch(t, []string{"call_1"}, callIDs) +} + +func TestHasFunctionCallOutputMissingCallID(t *testing.T) { + require.False(t, HasFunctionCallOutputMissingCallID(nil)) + require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output"}}, + })) + require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}}, + })) +} + +func TestHasItemReferenceForCallIDs(t *testing.T) { + // item_reference 需要覆盖所有 call_id 才视为可关联上下文。 + require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"})) + require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"})) + req := map[string]any{ + "input": []any{ + map[string]any{"type": "item_reference", "id": "call_1"}, + map[string]any{"type": "item_reference", "id": "call_2"}, + }, + } + require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"})) + require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"})) + require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"})) +}