fix(openai): support remote compact task

This commit is contained in:
神乐
2026-03-06 18:50:28 +08:00
parent 005d0c5f53
commit 3403909354
10 changed files with 446 additions and 57 deletions

View File

@@ -25,6 +25,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
@@ -49,6 +50,8 @@ const (
openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond
openAIWSRetryBackoffMaxDefault = 2 * time.Second
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
codexCLIVersion = "0.104.0"
)
// OpenAI allowed headers whitelist (for non-passthrough).
@@ -1614,7 +1617,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
if account.Type == AccountTypeOAuth {
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI)
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c))
if codexResult.Modified {
bodyModified = true
disablePatch()
@@ -2046,14 +2049,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
}
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c))
if err != nil {
return nil, err
}
if normalized {
body = normalizedBody
reqStream = true
}
reqStream = gjson.GetBytes(body, "stream").Bool()
}
logger.LegacyPrintf("service.openai_gateway",
@@ -2218,6 +2221,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
targetURL = buildOpenAIResponsesURL(validatedURL)
}
}
targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
if err != nil {
@@ -2251,7 +2255,15 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
if req.Header.Get("accept") == "" {
if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json")
if req.Header.Get("version") == "" {
req.Header.Set("version", codexCLIVersion)
}
if req.Header.Get("session_id") == "" {
req.Header.Set("session_id", resolveOpenAICompactSessionID(c))
}
} else if req.Header.Get("accept") == "" {
req.Header.Set("accept", "text/event-stream")
}
if req.Header.Get("OpenAI-Beta") == "" {
@@ -2598,6 +2610,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
default:
targetURL = openaiPlatformAPIURL
}
targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c))
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
@@ -2634,7 +2647,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
} else {
req.Header.Set("originator", "opencode")
}
req.Header.Set("accept", "text/event-stream")
if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json")
if req.Header.Get("version") == "" {
req.Header.Set("version", codexCLIVersion)
}
if req.Header.Get("session_id") == "" {
req.Header.Set("session_id", resolveOpenAICompactSessionID(c))
}
} else {
req.Header.Set("accept", "text/event-stream")
}
if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey)
req.Header.Set("session_id", promptCacheKey)
@@ -3425,6 +3448,95 @@ func buildOpenAIResponsesURL(base string) string {
return normalized + "/v1/responses"
}
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
return isOpenAIResponsesCompactPath(c)
}
func OpenAICompactSessionSeedKeyForTest() string {
return openAICompactSessionSeedKey
}
func NormalizeOpenAICompactRequestBodyForTest(body []byte) ([]byte, bool, error) {
return normalizeOpenAICompactRequestBody(body)
}
func isOpenAIResponsesCompactPath(c *gin.Context) bool {
suffix := strings.TrimSpace(openAIResponsesRequestPathSuffix(c))
return suffix == "/compact" || strings.HasPrefix(suffix, "/compact/")
}
func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
}
normalized := []byte(`{}`)
for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
value := gjson.GetBytes(body, field)
if !value.Exists() {
continue
}
next, err := sjson.SetRawBytes(normalized, field, []byte(value.Raw))
if err != nil {
return body, false, fmt.Errorf("normalize compact body %s: %w", field, err)
}
normalized = next
}
if bytes.Equal(bytes.TrimSpace(body), bytes.TrimSpace(normalized)) {
return body, false, nil
}
return normalized, true, nil
}
func resolveOpenAICompactSessionID(c *gin.Context) string {
if c != nil {
if sessionID := strings.TrimSpace(c.GetHeader("session_id")); sessionID != "" {
return sessionID
}
if conversationID := strings.TrimSpace(c.GetHeader("conversation_id")); conversationID != "" {
return conversationID
}
if seed, ok := c.Get(openAICompactSessionSeedKey); ok {
if seedStr, ok := seed.(string); ok && strings.TrimSpace(seedStr) != "" {
return strings.TrimSpace(seedStr)
}
}
}
return uuid.NewString()
}
func openAIResponsesRequestPathSuffix(c *gin.Context) string {
if c == nil || c.Request == nil || c.Request.URL == nil {
return ""
}
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
if normalizedPath == "" {
return ""
}
idx := strings.LastIndex(normalizedPath, "/responses")
if idx < 0 {
return ""
}
suffix := normalizedPath[idx+len("/responses"):]
if suffix == "" || suffix == "/" {
return ""
}
if !strings.HasPrefix(suffix, "/") {
return ""
}
return suffix
}
func appendOpenAIResponsesRequestPathSuffix(baseURL, suffix string) string {
trimmedBase := strings.TrimRight(strings.TrimSpace(baseURL), "/")
trimmedSuffix := strings.TrimSpace(suffix)
if trimmedBase == "" || trimmedSuffix == "" {
return trimmedBase
}
return trimmedBase + trimmedSuffix
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
@@ -3805,8 +3917,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
}
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
// 1) store=false 2) stream=true
func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
// 1) store=false 2) 非 compact 保持 stream=truecompact 强制 stream=false
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
}
@@ -3814,22 +3926,40 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
normalized := body
changed := false
if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False {
next, err := sjson.SetBytes(normalized, "store", false)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
if compact {
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
next, err := sjson.DeleteBytes(normalized, "store")
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body delete store: %w", err)
}
normalized = next
changed = true
}
normalized = next
changed = true
}
if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True {
next, err := sjson.SetBytes(normalized, "stream", true)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
if stream := gjson.GetBytes(normalized, "stream"); stream.Exists() {
next, err := sjson.DeleteBytes(normalized, "stream")
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body delete stream: %w", err)
}
normalized = next
changed = true
}
} else {
if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False {
next, err := sjson.SetBytes(normalized, "store", false)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
}
normalized = next
changed = true
}
if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True {
next, err := sjson.SetBytes(normalized, "stream", true)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
}
normalized = next
changed = true
}
normalized = next
changed = true
}
return normalized, changed, nil