Merge pull request #269 from mt21625457/main
fix: 修复opencode 适配openai 套餐的错误,通过sub2api完美转发 opencode
This commit is contained in:
@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## OpenAI Responses 兼容注意事项
|
||||||
|
|
||||||
|
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||||
|
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
@@ -114,6 +114,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
|
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||||
|
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||||
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
|
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||||
|
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Track if we've started streaming (for error handling)
|
// Track if we've started streaming (for error handling)
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
|
|||||||
var rateLimitScript = redis.NewScript(`
|
var rateLimitScript = redis.NewScript(`
|
||||||
local current = redis.call('INCR', KEYS[1])
|
local current = redis.call('INCR', KEYS[1])
|
||||||
local ttl = redis.call('PTTL', KEYS[1])
|
local ttl = redis.call('PTTL', KEYS[1])
|
||||||
if current == 1 or ttl == -1 then
|
local repaired = 0
|
||||||
|
if current == 1 then
|
||||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
elseif ttl == -1 then
|
||||||
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
repaired = 1
|
||||||
end
|
end
|
||||||
return current
|
return {current, repaired}
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
if len(values) < 2 {
|
||||||
|
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
||||||
|
}
|
||||||
|
count, err := parseInt64(values[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
repaired, err := parseInt64(values[1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
return count, repaired == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RateLimiter Redis 速率限制器
|
// RateLimiter Redis 速率限制器
|
||||||
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
|||||||
windowMillis := windowTTLMillis(window)
|
windowMillis := windowTTLMillis(window)
|
||||||
|
|
||||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
||||||
if failureMode == RateLimitFailClose {
|
if failureMode == RateLimitFailClose {
|
||||||
abortRateLimit(c)
|
abortRateLimit(c)
|
||||||
return
|
return
|
||||||
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
|||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if repaired {
|
||||||
|
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
||||||
|
}
|
||||||
|
|
||||||
// 超过限制
|
// 超过限制
|
||||||
if count > int64(limit) {
|
if count > int64(limit) {
|
||||||
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
|
|||||||
"message": "Too many requests, please try again later",
|
"message": "Too many requests, please try again later",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func failureModeLabel(mode RateLimitFailureMode) string {
|
||||||
|
if mode == RateLimitFailClose {
|
||||||
|
return "fail-close"
|
||||||
|
}
|
||||||
|
return "fail-open"
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInt64(value any) (int64, error) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int64:
|
||||||
|
return v, nil
|
||||||
|
case int:
|
||||||
|
return int64(v), nil
|
||||||
|
case string:
|
||||||
|
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unexpected value type %T", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
|||||||
originalRun := rateLimitRun
|
originalRun := rateLimitRun
|
||||||
counts := []int64{1, 2}
|
counts := []int64{1, 2}
|
||||||
callIndex := 0
|
callIndex := 0
|
||||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
if callIndex >= len(counts) {
|
if callIndex >= len(counts) {
|
||||||
return counts[len(counts)-1], nil
|
return counts[len(counts)-1], false, nil
|
||||||
}
|
}
|
||||||
value := counts[callIndex]
|
value := counts[callIndex]
|
||||||
callIndex++
|
callIndex++
|
||||||
return value, nil
|
return value, false, nil
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
rateLimitRun = originalRun
|
rateLimitRun = originalRun
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ type opencodeCacheMetadata struct {
|
|||||||
|
|
||||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||||
result := codexTransformResult{}
|
result := codexTransformResult{}
|
||||||
|
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||||
|
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||||
|
|
||||||
model := ""
|
model := ""
|
||||||
if v, ok := reqBody["model"].(string); ok {
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
@@ -88,9 +90,17 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
result.NormalizedModel = normalizedModel
|
result.NormalizedModel = normalizedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
// 续链场景强制启用 store;非续链仍按原策略强制关闭存储。
|
||||||
reqBody["store"] = false
|
if needsToolContinuation {
|
||||||
result.Modified = true
|
if v, ok := reqBody["store"].(bool); !ok || !v {
|
||||||
|
reqBody["store"] = true
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||||
|
reqBody["store"] = false
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||||
reqBody["stream"] = true
|
reqBody["stream"] = true
|
||||||
@@ -124,7 +134,7 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
} else if existingInstructions == "" {
|
} else if existingInstructions == "" {
|
||||||
// If no opencode instructions available, try codex CLI instructions
|
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
if codexInstructions != "" {
|
if codexInstructions != "" {
|
||||||
reqBody["instructions"] = codexInstructions
|
reqBody["instructions"] = codexInstructions
|
||||||
@@ -132,8 +142,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||||
if input, ok := reqBody["input"].([]any); ok {
|
if input, ok := reqBody["input"].([]any); ok {
|
||||||
input = filterCodexInput(input)
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
reqBody["input"] = input
|
reqBody["input"] = input
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
@@ -246,15 +257,15 @@ func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getOpenCodeCodexHeader() string {
|
func getOpenCodeCodexHeader() string {
|
||||||
// Try to get from opencode repository first
|
// 优先从 opencode 仓库缓存获取指令。
|
||||||
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||||
|
|
||||||
// If opencode instructions are available, return them
|
// 若 opencode 指令可用,直接返回。
|
||||||
if opencodeInstructions != "" {
|
if opencodeInstructions != "" {
|
||||||
return opencodeInstructions
|
return opencodeInstructions
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to local codex CLI instructions
|
// 否则回退使用本地 Codex CLI 指令。
|
||||||
return getCodexCLIInstructions()
|
return getCodexCLIInstructions()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,10 +277,12 @@ func GetOpenCodeInstructions() string {
|
|||||||
return getOpenCodeCodexHeader()
|
return getOpenCodeCodexHeader()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||||
func GetCodexCLIInstructions() string {
|
func GetCodexCLIInstructions() string {
|
||||||
return getCodexCLIInstructions()
|
return getCodexCLIInstructions()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||||
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
if codexInstructions == "" {
|
if codexInstructions == "" {
|
||||||
@@ -285,6 +298,7 @@ func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
|
||||||
func IsInstructionError(errorMessage string) bool {
|
func IsInstructionError(errorMessage string) bool {
|
||||||
if errorMessage == "" {
|
if errorMessage == "" {
|
||||||
return false
|
return false
|
||||||
@@ -309,7 +323,9 @@ func IsInstructionError(errorMessage string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterCodexInput(input []any) []any {
|
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||||
|
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||||
|
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||||
filtered := make([]any, 0, len(input))
|
filtered := make([]any, 0, len(input))
|
||||||
for _, item := range input {
|
for _, item := range input {
|
||||||
m, ok := item.(map[string]any)
|
m, ok := item.(map[string]any)
|
||||||
@@ -319,23 +335,49 @@ func filterCodexInput(input []any) []any {
|
|||||||
}
|
}
|
||||||
typ, _ := m["type"].(string)
|
typ, _ := m["type"].(string)
|
||||||
if typ == "item_reference" {
|
if typ == "item_reference" {
|
||||||
filtered = append(filtered, m)
|
if !preserveReferences {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newItem := make(map[string]any, len(m))
|
||||||
|
for key, value := range m {
|
||||||
|
newItem[key] = value
|
||||||
|
}
|
||||||
|
filtered = append(filtered, newItem)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Strip per-item ids; keep call_id only for tool call items so outputs can match.
|
|
||||||
|
newItem := m
|
||||||
|
copied := false
|
||||||
|
// 仅在需要修改字段时创建副本,避免直接改写原始输入。
|
||||||
|
ensureCopy := func() {
|
||||||
|
if copied {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newItem = make(map[string]any, len(m))
|
||||||
|
for key, value := range m {
|
||||||
|
newItem[key] = value
|
||||||
|
}
|
||||||
|
copied = true
|
||||||
|
}
|
||||||
|
|
||||||
if isCodexToolCallItemType(typ) {
|
if isCodexToolCallItemType(typ) {
|
||||||
callID, _ := m["call_id"].(string)
|
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
||||||
if strings.TrimSpace(callID) == "" {
|
|
||||||
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||||
m["call_id"] = id
|
ensureCopy()
|
||||||
|
newItem["call_id"] = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(m, "id")
|
|
||||||
if !isCodexToolCallItemType(typ) {
|
if !preserveReferences {
|
||||||
delete(m, "call_id")
|
ensureCopy()
|
||||||
|
delete(newItem, "id")
|
||||||
|
if !isCodexToolCallItemType(typ) {
|
||||||
|
delete(newItem, "call_id")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
filtered = append(filtered, m)
|
|
||||||
|
filtered = append(filtered, newItem)
|
||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|||||||
147
backend/internal/service/openai_codex_transform_test.go
Normal file
147
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||||
|
// 续链场景:保留 item_reference 与 id,并启用 store。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 2)
|
||||||
|
|
||||||
|
// 校验 input[0] 为 map,避免断言失败导致测试中断。
|
||||||
|
first, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "item_reference", first["type"])
|
||||||
|
require.Equal(t, "ref1", first["id"])
|
||||||
|
|
||||||
|
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||||
|
second, ok := input[1].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "o1", second["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_ToolContinuationForcesStoreTrue(t *testing.T) {
|
||||||
|
// 续链场景:显式 store=false 也会被强制为 true。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": false,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_NonContinuationForcesStoreFalseAndStripsIDs(t *testing.T) {
|
||||||
|
// 非续链场景:强制 store=false,并移除 input 中的 id。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"store": true,
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 1)
|
||||||
|
// 校验 input[0] 为 map,避免类型不匹配触发 errcheck。
|
||||||
|
item, ok := input[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||||
|
input := []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "ref1"},
|
||||||
|
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := filterCodexInput(input, false)
|
||||||
|
require.Len(t, filtered, 1)
|
||||||
|
// 校验 filtered[0] 为 map,确保字段检查可靠。
|
||||||
|
item, ok := filtered[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "text", item["type"])
|
||||||
|
_, hasID := item["id"]
|
||||||
|
require.False(t, hasID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||||
|
// 空 input 应保持为空且不触发异常。
|
||||||
|
setupCodexCache(t)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.1",
|
||||||
|
"input": []any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCodexOAuthTransform(reqBody)
|
||||||
|
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, input, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupCodexCache(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 使用临时 HOME 避免触发网络拉取 header。
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", tempDir)
|
||||||
|
|
||||||
|
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||||
|
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||||
|
|
||||||
|
meta := map[string]any{
|
||||||
|
"etag": "",
|
||||||
|
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"lastChecked": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(meta)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||||
|
}
|
||||||
@@ -546,7 +546,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
|
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
|
|
||||||
// Apply model mapping for all requests (including Codex CLI)
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||||
@@ -554,7 +554,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply Codex model normalization for all OpenAI accounts
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||||
if model, ok := reqBody["model"].(string); ok {
|
if model, ok := reqBody["model"].(string); ok {
|
||||||
normalizedModel := normalizeCodexModel(model)
|
normalizedModel := normalizeCodexModel(model)
|
||||||
if normalizedModel != "" && normalizedModel != model {
|
if normalizedModel != "" && normalizedModel != model {
|
||||||
@@ -566,7 +566,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize reasoning.effort parameter (minimal -> none)
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||||
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||||
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
||||||
reasoning["effort"] = "none"
|
reasoning["effort"] = "none"
|
||||||
|
|||||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||||
|
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||||
|
// 或显式声明 tools/tool_choice。
|
||||||
|
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolsSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolChoiceSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "function_call_output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "item_reference") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||||
|
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return inputHasType(reqBody, "function_call_output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||||
|
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||||
|
func HasToolCallContext(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "tool_call" && itemType != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||||
|
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||||
|
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||||
|
if reqBody == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ids := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
ids[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(ids))
|
||||||
|
for id := range ids {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||||
|
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID, _ := itemMap["call_id"].(string)
|
||||||
|
if strings.TrimSpace(callID) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||||
|
// 用于仅依赖引用项完成续链场景的校验。
|
||||||
|
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||||
|
if reqBody == nil || len(callIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
referenceIDs := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "item_reference" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idValue, _ := itemMap["id"].(string)
|
||||||
|
idValue = strings.TrimSpace(idValue)
|
||||||
|
if idValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
referenceIDs[idValue] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(referenceIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, callID := range callIDs {
|
||||||
|
if _, ok := referenceIDs[callID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||||
|
func inputHasType(reqBody map[string]any, want string) bool {
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||||
|
func hasNonEmptyString(value any) bool {
|
||||||
|
stringValue, ok := value.(string)
|
||||||
|
return ok && strings.TrimSpace(stringValue) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||||
|
func hasToolsSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tools"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if tools, ok := raw.([]any); ok {
|
||||||
|
return len(tools) > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||||
|
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tool_choice"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch value := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(value) != ""
|
||||||
|
case map[string]any:
|
||||||
|
return len(value) > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||||
|
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body map[string]any
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "nil", body: nil, want: false},
|
||||||
|
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||||
|
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||||
|
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||||
|
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||||
|
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||||
|
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||||
|
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||||
|
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||||
|
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||||
|
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||||
|
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutput(t *testing.T) {
|
||||||
|
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||||
|
require.False(t, HasFunctionCallOutput(nil))
|
||||||
|
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": "text",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasToolCallContext(t *testing.T) {
|
||||||
|
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||||
|
require.False(t, HasToolCallContext(nil))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||||
|
// 仅提取非空 call_id,去重后返回。
|
||||||
|
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||||
|
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||||
|
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||||
|
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||||
|
req := map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user