Merge pull request #269 from mt21625457/main

fix: 修复opencode 适配openai 套餐的错误,通过sub2api完美转发 opencode
This commit is contained in:
Wesley Liddick
2026-01-13 17:33:07 +08:00
committed by GitHub
9 changed files with 606 additions and 29 deletions

View File

@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
---
## 部署方式
### 方式一:脚本安装(推荐)

View File

@@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
// Track if we've started streaming (for error handling)
streamStarted := false

View File

@@ -2,7 +2,10 @@ package middleware
import (
"context"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl == -1 then
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
return current
return {current, repaired}
`)
// rateLimitRun 允许测试覆写脚本执行逻辑
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
if err != nil {
return 0, false, err
}
if len(values) < 2 {
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
}
count, err := parseInt64(values[0])
if err != nil {
return 0, false, err
}
repaired, err := parseInt64(values[1])
if err != nil {
return 0, false, err
}
return count, repaired == 1, nil
}
// RateLimiter Redis 速率限制器
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
windowMillis := windowTTLMillis(window)
// 使用 Lua 脚本原子操作增加计数并设置过期
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
if err != nil {
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
if failureMode == RateLimitFailClose {
abortRateLimit(c)
return
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
c.Next()
return
}
if repaired {
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
}
// 超过限制
if count > int64(limit) {
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
"message": "Too many requests, please try again later",
})
}
func failureModeLabel(mode RateLimitFailureMode) string {
if mode == RateLimitFailClose {
return "fail-close"
}
return "fail-open"
}
func parseInt64(value any) (int64, error) {
switch v := value.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, err
}
return parsed, nil
default:
return 0, fmt.Errorf("unexpected value type %T", value)
}
}

View File

@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
originalRun := rateLimitRun
counts := []int64{1, 2}
callIndex := 0
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
if callIndex >= len(counts) {
return counts[len(counts)-1], nil
return counts[len(counts)-1], false, nil
}
value := counts[callIndex]
callIndex++
return value, nil
return value, false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun

View File

@@ -74,6 +74,8 @@ type opencodeCacheMetadata struct {
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
needsToolContinuation := NeedsToolContinuation(reqBody)
model := ""
if v, ok := reqBody["model"].(string); ok {
@@ -88,9 +90,17 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result.NormalizedModel = normalizedModel
}
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
// 续链场景强制启用 store非续链仍按原策略强制关闭存储。
if needsToolContinuation {
if v, ok := reqBody["store"].(bool); !ok || !v {
reqBody["store"] = true
result.Modified = true
}
} else {
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
@@ -124,7 +134,7 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result.Modified = true
}
} else if existingInstructions == "" {
// If no opencode instructions available, try codex CLI instructions
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions != "" {
reqBody["instructions"] = codexInstructions
@@ -132,8 +142,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
}
}
// 续链场景保留 item_reference 与 id避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
input = filterCodexInput(input)
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
}
@@ -246,15 +257,15 @@ func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
}
func getOpenCodeCodexHeader() string {
// Try to get from opencode repository first
// 优先从 opencode 仓库缓存获取指令。
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
// If opencode instructions are available, return them
// opencode 指令可用,直接返回。
if opencodeInstructions != "" {
return opencodeInstructions
}
// Fallback to local codex CLI instructions
// 否则回退使用本地 Codex CLI 指令。
return getCodexCLIInstructions()
}
@@ -266,10 +277,12 @@ func GetOpenCodeInstructions() string {
return getOpenCodeCodexHeader()
}
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
func GetCodexCLIInstructions() string {
return getCodexCLIInstructions()
}
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
if codexInstructions == "" {
@@ -285,6 +298,7 @@ func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
return false
}
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
func IsInstructionError(errorMessage string) bool {
if errorMessage == "" {
return false
@@ -309,7 +323,9 @@ func IsInstructionError(errorMessage string) bool {
return false
}
func filterCodexInput(input []any) []any {
// filterCodexInput 按需过滤 item_reference 与 id。
// preserveReferences 为 true 时保持引用与 id以满足续链请求对上下文的依赖。
func filterCodexInput(input []any, preserveReferences bool) []any {
filtered := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
@@ -319,23 +335,49 @@ func filterCodexInput(input []any) []any {
}
typ, _ := m["type"].(string)
if typ == "item_reference" {
filtered = append(filtered, m)
if !preserveReferences {
continue
}
newItem := make(map[string]any, len(m))
for key, value := range m {
newItem[key] = value
}
filtered = append(filtered, newItem)
continue
}
// Strip per-item ids; keep call_id only for tool call items so outputs can match.
newItem := m
copied := false
// 仅在需要修改字段时创建副本,避免直接改写原始输入。
ensureCopy := func() {
if copied {
return
}
newItem = make(map[string]any, len(m))
for key, value := range m {
newItem[key] = value
}
copied = true
}
if isCodexToolCallItemType(typ) {
callID, _ := m["call_id"].(string)
if strings.TrimSpace(callID) == "" {
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
m["call_id"] = id
ensureCopy()
newItem["call_id"] = id
}
}
}
delete(m, "id")
if !isCodexToolCallItemType(typ) {
delete(m, "call_id")
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
}
filtered = append(filtered, m)
filtered = append(filtered, newItem)
}
return filtered
}

View File

@@ -0,0 +1,147 @@
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id并启用 store。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.True(t, store)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
// 校验 input[0] 为 map避免断言失败导致测试中断。
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "item_reference", first["type"])
require.Equal(t, "ref1", first["id"])
// 校验 input[1] 为 map确保后续字段断言安全。
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "o1", second["id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) {
// 续链场景:显式 store=false 也会被强制为 true。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"store": false,
"input": []any{
map[string]any{"type": "function_call_output", "call_id": "call_1"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.True(t, store)
}
func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:强制 store=false并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"store": true,
"input": []any{
map[string]any{"type": "text", "id": "t1", "text": "hi"},
},
}
applyCodexOAuthTransform(reqBody)
store, ok := reqBody["store"].(bool)
require.True(t, ok)
require.False(t, store)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 1)
// 校验 input[0] 为 map避免类型不匹配触发 errcheck。
item, ok := input[0].(map[string]any)
require.True(t, ok)
_, hasID := item["id"]
require.False(t, hasID)
}
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
input := []any{
map[string]any{"type": "item_reference", "id": "ref1"},
map[string]any{"type": "text", "id": "t1", "text": "hi"},
}
filtered := filterCodexInput(input, false)
require.Len(t, filtered, 1)
// 校验 filtered[0] 为 map确保字段检查可靠。
item, ok := filtered[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", item["type"])
_, hasID := item["id"]
require.False(t, hasID)
}
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
applyCodexOAuthTransform(reqBody)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 0)
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}

View File

@@ -546,7 +546,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// Apply model mapping for all requests (including Codex CLI)
// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
@@ -554,7 +554,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
}
// Apply Codex model normalization for all OpenAI accounts
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
@@ -566,7 +566,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
// Normalize reasoning.effort parameter (minimal -> none)
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"

View File

@@ -0,0 +1,213 @@
package service
import "strings"
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
if hasNonEmptyString(reqBody["previous_response_id"]) {
return true
}
if hasToolsSignal(reqBody) {
return true
}
if hasToolChoiceSignal(reqBody) {
return true
}
if inputHasType(reqBody, "function_call_output") {
return true
}
if inputHasType(reqBody, "item_reference") {
return true
}
return false
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output用于触发续链校验。
func HasFunctionCallOutput(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return inputHasType(reqBody, "function_call_output")
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call
// 用于判断 function_call_output 是否具备可关联的上下文。
func HasToolCallContext(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "tool_call" && itemType != "function_call" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
return true
}
}
return false
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// 仅返回非空 call_id用于与 item_reference.id 做匹配校验。
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
if reqBody == nil {
return nil
}
input, ok := reqBody["input"].([]any)
if !ok {
return nil
}
ids := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
ids[callID] = struct{}{}
}
}
if len(ids) == 0 {
return nil
}
result := make([]string, 0, len(ids))
for id := range ids {
result = append(result, id)
}
return result
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) == "" {
return true
}
}
return false
}
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
// 用于仅依赖引用项完成续链场景的校验。
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
if reqBody == nil || len(callIDs) == 0 {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
referenceIDs := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "item_reference" {
continue
}
idValue, _ := itemMap["id"].(string)
idValue = strings.TrimSpace(idValue)
if idValue == "" {
continue
}
referenceIDs[idValue] = struct{}{}
}
if len(referenceIDs) == 0 {
return false
}
for _, callID := range callIDs {
if _, ok := referenceIDs[callID]; !ok {
return false
}
}
return true
}
// inputHasType 判断 input 中是否存在指定类型的 item。
func inputHasType(reqBody map[string]any, want string) bool {
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == want {
return true
}
}
return false
}
// hasNonEmptyString 判断字段是否为非空字符串。
func hasNonEmptyString(value any) bool {
stringValue, ok := value.(string)
return ok && strings.TrimSpace(stringValue) != ""
}
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
func hasToolsSignal(reqBody map[string]any) bool {
raw, exists := reqBody["tools"]
if !exists || raw == nil {
return false
}
if tools, ok := raw.([]any); ok {
return len(tools) > 0
}
return false
}
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil
func hasToolChoiceSignal(reqBody map[string]any) bool {
raw, exists := reqBody["tool_choice"]
if !exists || raw == nil {
return false
}
switch value := raw.(type) {
case string:
return strings.TrimSpace(value) != ""
case map[string]any:
return len(value) > 0
default:
return false
}
}

View File

@@ -0,0 +1,98 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNeedsToolContinuationSignals(t *testing.T) {
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
cases := []struct {
name string
body map[string]any
want bool
}{
{name: "nil", body: nil, want: false},
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
})
}
}
func TestHasFunctionCallOutput(t *testing.T) {
// 仅当 input 中存在 function_call_output 才视为续链输出。
require.False(t, HasFunctionCallOutput(nil))
require.True(t, HasFunctionCallOutput(map[string]any{
"input": []any{map[string]any{"type": "function_call_output"}},
}))
require.False(t, HasFunctionCallOutput(map[string]any{
"input": "text",
}))
}
func TestHasToolCallContext(t *testing.T) {
// tool_call/function_call 必须包含 call_id才能作为可关联上下文。
require.False(t, HasToolCallContext(nil))
require.True(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
}))
require.True(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
}))
require.False(t, HasToolCallContext(map[string]any{
"input": []any{map[string]any{"type": "tool_call"}},
}))
}
func TestFunctionCallOutputCallIDs(t *testing.T) {
// 仅提取非空 call_id去重后返回。
require.Empty(t, FunctionCallOutputCallIDs(nil))
callIDs := FunctionCallOutputCallIDs(map[string]any{
"input": []any{
map[string]any{"type": "function_call_output", "call_id": "call_1"},
map[string]any{"type": "function_call_output", "call_id": ""},
map[string]any{"type": "function_call_output", "call_id": "call_1"},
},
})
require.ElementsMatch(t, []string{"call_1"}, callIDs)
}
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
require.False(t, HasFunctionCallOutputMissingCallID(nil))
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output"}},
}))
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
}))
}
func TestHasItemReferenceForCallIDs(t *testing.T) {
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
req := map[string]any{
"input": []any{
map[string]any{"type": "item_reference", "id": "call_1"},
map[string]any{"type": "item_reference", "id": "call_2"},
},
}
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
}