Files
sub2api/backend/internal/service/openai_tool_continuation.go
2026-02-28 15:01:20 +08:00

297 lines
8.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import "strings"
// ToolContinuationSignals aggregates the tool-continuation signals found in a
// request's "input" array, so callers can inspect several of them without
// walking the input more than once.
type ToolContinuationSignals struct {
	// HasFunctionCallOutput is true when input contains at least one item
	// whose type is "function_call_output".
	HasFunctionCallOutput bool
	// HasFunctionCallOutputMissingCallID is true when some
	// "function_call_output" item has a call_id that is absent, not a
	// string, or blank after trimming whitespace.
	HasFunctionCallOutputMissingCallID bool
	// HasToolCallContext is true when input contains a "tool_call" or
	// "function_call" item carrying a non-blank call_id.
	HasToolCallContext bool
	// HasItemReference is true when input contains at least one item whose
	// type is "item_reference" (regardless of whether its id is usable).
	HasItemReference bool
	// HasItemReferenceForAllCallIDs is true when there is at least one
	// non-blank function_call_output call_id and every such call_id equals
	// the trimmed id of some "item_reference" item.
	HasItemReferenceForAllCallIDs bool
	// FunctionCallOutputCallIDs holds the distinct, whitespace-trimmed,
	// non-blank call_id values of the "function_call_output" items. It is
	// nil when there are none.
	FunctionCallOutputCallIDs []string
}
// FunctionCallOutputValidation summarises whether the "function_call_output"
// items of a request can be associated with a tool call. It is the result
// type of ValidateFunctionCallOutputContext.
type FunctionCallOutputValidation struct {
	// HasFunctionCallOutput is true when input contains at least one
	// "function_call_output" item.
	HasFunctionCallOutput bool
	// HasToolCallContext is true when input contains a "tool_call" or
	// "function_call" item carrying a non-blank call_id.
	HasToolCallContext bool
	// HasFunctionCallOutputMissingCallID is true when a
	// "function_call_output" item lacks a usable call_id. The validator only
	// evaluates this when outputs exist and no tool-call context was found.
	HasFunctionCallOutputMissingCallID bool
	// HasItemReferenceForAllCallIDs is true when every non-blank output
	// call_id is matched by an "item_reference" id. The validator only
	// evaluates this when outputs exist and no tool-call context was found.
	HasItemReferenceForAllCallIDs bool
}
// NeedsToolContinuation reports whether the request requires tool-call
// continuation handling. Any one of the following signals is sufficient:
//   - a non-blank string "previous_response_id";
//   - an explicit "tools" or "tool_choice" declaration;
//   - an "input" array containing a "function_call_output" or
//     "item_reference" item.
//
// A nil request body never needs continuation.
func NeedsToolContinuation(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}

	// Top-level signals are cheap to test, so check them before scanning input.
	if hasNonEmptyString(reqBody["previous_response_id"]) ||
		hasToolsSignal(reqBody) ||
		hasToolChoiceSignal(reqBody) {
		return true
	}

	entries, isList := reqBody["input"].([]any)
	if !isList {
		return false
	}
	for _, raw := range entries {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			// Non-object input items carry no type and cannot be a signal.
			continue
		}
		switch kind, _ := entry["type"].(string); kind {
		case "function_call_output", "item_reference":
			return true
		}
	}
	return false
}
// AnalyzeToolContinuationSignals walks reqBody["input"] once and extracts the
// signals related to function_call_output, tool_call/function_call and
// item_reference items.
//
// It returns the zero ToolContinuationSignals when reqBody is nil or "input"
// is not a JSON array. Items that are not JSON objects are ignored.
//
// FunctionCallOutputCallIDs lists each distinct, whitespace-trimmed, non-blank
// call_id exactly once, in order of first appearance in the input, so the
// result is deterministic for a given request. It stays nil when no
// function_call_output carries a usable call_id.
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
	signals := ToolContinuationSignals{}
	if reqBody == nil {
		return signals
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return signals
	}

	// Both sets are allocated lazily: most requests contain neither kind of
	// item, and lookups on a nil map are safe.
	var (
		seenCallIDs  map[string]struct{}
		referenceIDs map[string]struct{}
	)
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "tool_call", "function_call":
			callID, _ := itemMap["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				signals.HasToolCallContext = true
			}
		case "function_call_output":
			signals.HasFunctionCallOutput = true
			callID, _ := itemMap["call_id"].(string)
			callID = strings.TrimSpace(callID)
			if callID == "" {
				signals.HasFunctionCallOutputMissingCallID = true
				continue
			}
			if _, dup := seenCallIDs[callID]; dup {
				continue
			}
			if seenCallIDs == nil {
				seenCallIDs = make(map[string]struct{})
			}
			seenCallIDs[callID] = struct{}{}
			// Append in encounter order rather than ranging over the set
			// afterwards: Go map iteration order is unspecified, which would
			// make the returned slice differ between identical requests.
			signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
		case "item_reference":
			signals.HasItemReference = true
			idValue, _ := itemMap["id"].(string)
			idValue = strings.TrimSpace(idValue)
			if idValue == "" {
				continue
			}
			if referenceIDs == nil {
				referenceIDs = make(map[string]struct{})
			}
			referenceIDs[idValue] = struct{}{}
		}
	}

	// Coverage can only hold when there is at least one call_id and at least
	// one reference to match it against.
	if len(signals.FunctionCallOutputCallIDs) == 0 || len(referenceIDs) == 0 {
		return signals
	}
	for _, callID := range signals.FunctionCallOutputCallIDs {
		if _, ok := referenceIDs[callID]; !ok {
			return signals
		}
	}
	signals.HasItemReferenceForAllCallIDs = true
	return signals
}
// ValidateFunctionCallOutputContext gives handlers a low-overhead check of
// whether the request's function_call_output items can be associated with a
// tool call:
//  1. with no function_call_output in input, it returns immediately after the
//     first scan;
//  2. once a tool_call/function_call with a non-blank call_id is seen together
//     with an output, it returns early;
//  3. only when outputs exist without any tool-call context does it build the
//     call_id and item_reference sets needed for the remaining two fields.
//
// A nil body or a non-array "input" yields the zero value.
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
	var verdict FunctionCallOutputValidation
	if reqBody == nil {
		return verdict
	}
	entries, isList := reqBody["input"].([]any)
	if !isList {
		return verdict
	}

	// First scan: flag detection only, no allocations.
	for _, raw := range entries {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		kind, _ := entry["type"].(string)
		switch {
		case kind == "function_call_output":
			verdict.HasFunctionCallOutput = true
		case kind == "tool_call" || kind == "function_call":
			id, _ := entry["call_id"].(string)
			if strings.TrimSpace(id) != "" {
				verdict.HasToolCallContext = true
			}
		}
		if verdict.HasFunctionCallOutput && verdict.HasToolCallContext {
			return verdict
		}
	}
	if verdict.HasToolCallContext || !verdict.HasFunctionCallOutput {
		return verdict
	}

	// Second scan: outputs exist but nothing in input provides context, so
	// check whether item_reference entries cover every output call_id.
	outputIDs := map[string]struct{}{}
	refIDs := map[string]struct{}{}
	for _, raw := range entries {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		kind, _ := entry["type"].(string)
		switch {
		case kind == "function_call_output":
			id, _ := entry["call_id"].(string)
			if id = strings.TrimSpace(id); id == "" {
				verdict.HasFunctionCallOutputMissingCallID = true
			} else {
				outputIDs[id] = struct{}{}
			}
		case kind == "item_reference":
			id, _ := entry["id"].(string)
			if id = strings.TrimSpace(id); id != "" {
				refIDs[id] = struct{}{}
			}
		}
	}
	if len(outputIDs) == 0 || len(refIDs) == 0 {
		return verdict
	}

	// Assume full coverage and retract it at the first unmatched call_id.
	verdict.HasItemReferenceForAllCallIDs = true
	for id := range outputIDs {
		if _, found := refIDs[id]; !found {
			verdict.HasItemReferenceForAllCallIDs = false
			break
		}
	}
	return verdict
}
// HasFunctionCallOutput reports whether input contains a function_call_output
// item; callers use it to decide whether continuation validation is needed.
//
// Every call runs a full AnalyzeToolContinuationSignals pass. Callers that
// need several signals should call that function once instead.
func HasFunctionCallOutput(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutput
}
// HasToolCallContext reports whether input contains a tool_call or
// function_call item with a non-blank call_id, i.e. whether a
// function_call_output has context it could be associated with.
//
// Every call runs a full AnalyzeToolContinuationSignals pass.
func HasToolCallContext(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasToolCallContext
}
// FunctionCallOutputCallIDs returns the distinct call_id values of the
// function_call_output items in input. Blank call_ids are excluded, so the
// result can be matched directly against item_reference ids. It returns nil
// when there are none.
//
// Every call runs a full AnalyzeToolContinuationSignals pass.
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID reports whether input contains a
// function_call_output item that lacks a usable (non-blank string) call_id.
//
// Every call runs a full AnalyzeToolContinuationSignals pass.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutputMissingCallID
}
// HasItemReferenceForCallIDs reports whether the item_reference ids in input
// cover every entry of callIDs. It validates continuation requests that rely
// solely on referenced items.
//
// Both the reference ids and the supplied call IDs are compared after trimming
// surrounding whitespace. The result is false when reqBody is nil, callIDs is
// empty, "input" is not an array, or input holds no item_reference with a
// non-blank id.
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if len(callIDs) == 0 || reqBody == nil {
		return false
	}
	entries, isList := reqBody["input"].([]any)
	if !isList {
		return false
	}

	// Gather every usable reference id first, then test the call IDs against it.
	known := map[string]struct{}{}
	for _, raw := range entries {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := entry["type"].(string); kind != "item_reference" {
			continue
		}
		refID, _ := entry["id"].(string)
		if refID = strings.TrimSpace(refID); refID != "" {
			known[refID] = struct{}{}
		}
	}
	if len(known) == 0 {
		return false
	}

	for _, wanted := range callIDs {
		if _, found := known[strings.TrimSpace(wanted)]; !found {
			return false
		}
	}
	return true
}
// hasNonEmptyString reports whether value is a string that still has content
// after surrounding whitespace is trimmed. Any non-string value, including
// nil, yields false.
func hasNonEmptyString(value any) bool {
	text, isString := value.(string)
	if !isString {
		return false
	}
	return len(strings.TrimSpace(text)) > 0
}
// hasToolsSignal reports whether the request explicitly declares tools, i.e.
// "tools" is present and is a non-empty array. A missing key, a nil value, or
// a value of any other type yields false.
func hasToolsSignal(reqBody map[string]any) bool {
	// A missing key or nil value fails the type assertion, so neither needs a
	// separate check.
	declared, isList := reqBody["tools"].([]any)
	return isList && len(declared) > 0
}
// hasToolChoiceSignal reports whether the request explicitly declares
// "tool_choice": either a string that is non-blank after trimming, or a
// non-empty object. A missing key, a nil value, or any other type yields false.
func hasToolChoiceSignal(reqBody map[string]any) bool {
	// A missing key or nil value matches neither assertion below and falls
	// through to the final false.
	choice := reqBody["tool_choice"]
	if text, isString := choice.(string); isString {
		return strings.TrimSpace(text) != ""
	}
	if object, isObject := choice.(map[string]any); isObject {
		return len(object) > 0
	}
	return false
}