Files
sub2api-ht/backend/internal/service/openai_messages_continuation.go
Jack 87d73236f2 Preserve multi-tool context in OpenAI messages continuation
Claude Code can send one assistant turn with multiple tool_use blocks followed by a user turn containing matching tool_result blocks. The OpenAI /v1/messages compatibility path trimmed continuation input to the last user turn plus adjacent tool outputs, which could leave a function_call_output without its earlier function_call when previous_response_id was attached.

This keeps all function_call items needed by retained function_call_output entries so the upstream Responses API can resolve every call_id.

Constraint: Applies only to the OpenAI /v1/messages -> Responses compatibility continuation path.

Rejected: Disable previous_response_id for all tool outputs | loses continuation and cache benefits for valid turns.

Confidence: high

Scope-risk: narrow

Directive: Do not trim function_call_output entries without preserving their matching function_call call_id context.

Tested: go test ./internal/service -run 'TestForwardAsAnthropic_(PreviousResponseIDKeepsMultiToolCallContext|AttachesPreviousResponseIDForCompatContinuation|OAuthPreservesClaudeCodeToolCallID)' -count=1

Tested: go test ./internal/service -run 'TestForwardAsAnthropic|TestApplyAnthropicCompatFullReplayGuard|TestOpenAICompat|Test.*ToolContinuation' -count=1

Tested: go test ./internal/pkg/apicompat -count=1

Related: #2337
2026-05-11 12:03:17 +08:00

332 lines
9.3 KiB
Go

package service
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// openAICompatSessionResponseBinding is the per-session record kept in
// OpenAIGatewayService.openaiCompatSessionResponses for the OpenAI compat
// continuation path.
type openAICompatSessionResponseBinding struct {
// ResponseID is the upstream response ID bound to the session; blank when none is bound.
ResponseID string
// TurnState is a turn-state string stored alongside the response ID; it
// survives response-ID deletion and continuation disabling.
TurnState string
// ContinuationDisabled makes the response-ID getter return "" for this session.
ContinuationDisabled bool
// ExpiresAt is the instant after which accessors evict the binding; the
// zero value is treated as "never expires".
ExpiresAt time.Time
}
// openAICompatContinuationEnabled reports whether account is a non-nil
// API-key account and shouldAutoInjectPromptCacheKeyForCompat accepts model.
// The model check is skipped entirely when the account does not qualify.
func openAICompatContinuationEnabled(account *Account, model string) bool {
	return account != nil &&
		account.Type == AccountTypeAPIKey &&
		shouldAutoInjectPromptCacheKeyForCompat(model)
}
// trimAnthropicCompatResponsesInputToLatestTurn rewrites req.Input in place so
// that it holds only the items from the latest turn onwards, as located by
// latestAnthropicCompatResponsesInputTurnStart.
//
// The request is left untouched when it is nil, has no input, the input does
// not decode into a non-empty item list, nothing would be trimmed, or the
// trimmed list cannot be re-encoded.
func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) {
	if req == nil || len(req.Input) == 0 {
		return
	}

	var parsed []apicompat.ResponsesInputItem
	if json.Unmarshal(req.Input, &parsed) != nil || len(parsed) == 0 {
		return
	}

	first := latestAnthropicCompatResponsesInputTurnStart(parsed)
	tail := append([]apicompat.ResponsesInputItem(nil), parsed[first:]...)
	if len(tail) == len(parsed) {
		// Nothing was cut off; keep the caller's original bytes.
		return
	}

	encoded, err := json.Marshal(tail)
	if err != nil {
		// Best effort: an encoding failure keeps the untrimmed input.
		return
	}
	req.Input = encoded
}
// latestAnthropicCompatResponsesInputTurnStart returns the index at which the
// latest turn of items begins.
//
// When the final item is a function_call_output or a user message, the start
// is moved back over the contiguous run of function_call_output items that
// precedes it, then widened by expandAnthropicCompatResponsesInputToolCallStart
// so that matching function_call items are included. For any other final item
// the index of that item is returned. An empty slice yields 0.
func latestAnthropicCompatResponsesInputTurnStart(items []apicompat.ResponsesInputItem) int {
	if len(items) == 0 {
		return 0
	}

	idx := len(items) - 1
	tail := items[idx]
	endsWithToolOutput := tail.Type == "function_call_output"
	endsWithUserMessage := tail.Type == "message" && tail.Role == "user"
	if !endsWithToolOutput && !endsWithUserMessage {
		return idx
	}

	// Absorb every tool output immediately preceding the final item.
	for ; idx > 0; idx-- {
		if items[idx-1].Type != "function_call_output" {
			break
		}
	}
	return expandAnthropicCompatResponsesInputToolCallStart(items, idx)
}
// expandAnthropicCompatResponsesInputToolCallStart moves start backwards so
// that every function_call_output in items[start:] with a non-blank call ID
// has its matching function_call inside the returned range.
//
// It returns start unchanged when start is out of range or when no retained
// output carries a call ID. Otherwise it returns the index of the earliest
// matching function_call found while scanning backwards from start-1; call IDs
// with no matching function_call simply leave the result where it is.
func expandAnthropicCompatResponsesInputToolCallStart(items []apicompat.ResponsesInputItem, start int) int {
	if start < 0 || start >= len(items) {
		return start
	}

	// Collect the call IDs referenced by the retained tool outputs.
	pending := map[string]struct{}{}
	for i := start; i < len(items); i++ {
		item := &items[i]
		if item.Type == "function_call_output" {
			if id := strings.TrimSpace(item.CallID); id != "" {
				pending[id] = struct{}{}
			}
		}
	}
	if len(pending) == 0 {
		return start
	}

	// Walk backwards until every pending ID is resolved or the input is exhausted.
	result := start
	for i := start - 1; i >= 0; i-- {
		if len(pending) == 0 {
			break
		}
		item := &items[i]
		if item.Type != "function_call" {
			continue
		}
		id := strings.TrimSpace(item.CallID)
		if _, wanted := pending[id]; wanted {
			delete(pending, id)
			result = i
		}
	}
	return result
}
// isOpenAICompatPreviousResponseNotFound reports whether an upstream 400/404
// error indicates that the referenced previous response could not be used.
//
// The upstream message, the raw body, and the body's error.code and
// error.message JSON fields are each matched case-insensitively against:
// "previous_response_not_found"; "previous response" together with
// "not found"; or "unsupported parameter" together with "previous_response_id".
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	switch statusCode {
	case http.StatusBadRequest, http.StatusNotFound:
		// Only these statuses are inspected further.
	default:
		return false
	}

	matches := func(text string) bool {
		s := strings.ToLower(strings.TrimSpace(text))
		switch {
		case strings.Contains(s, "previous_response_not_found"):
			return true
		case strings.Contains(s, "previous response") && strings.Contains(s, "not found"):
			return true
		case strings.Contains(s, "unsupported parameter") && strings.Contains(s, "previous_response_id"):
			return true
		}
		return false
	}

	if matches(upstreamMsg) {
		return true
	}
	if matches(string(upstreamBody)) {
		return true
	}
	// Fall back to the decoded JSON fields.
	for _, path := range []string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// isOpenAICompatPreviousResponseUnsupported reports whether an upstream 400
// error says that previous_response_id is not supported.
//
// The upstream message, the raw body, and the body's error.code and
// error.message JSON fields are each matched case-insensitively: the text must
// mention "previous_response_id" and one of "unsupported parameter",
// "only supported on responses websocket", or "not supported".
func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	if statusCode != http.StatusBadRequest {
		return false
	}

	phrases := []string{
		"unsupported parameter",
		"only supported on responses websocket",
		"not supported",
	}
	matches := func(text string) bool {
		s := strings.ToLower(strings.TrimSpace(text))
		if !strings.Contains(s, "previous_response_id") {
			return false
		}
		for _, phrase := range phrases {
			if strings.Contains(s, phrase) {
				return true
			}
		}
		return false
	}

	if matches(upstreamMsg) {
		return true
	}
	if matches(string(upstreamBody)) {
		return true
	}
	// Fall back to the decoded JSON fields.
	for _, path := range []string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// openAICompatSessionResponseKey builds the lookup key for a compat session:
// the account ID, the API key ID read from c (0 when c is nil), and the
// trimmed prompt cache key, joined with NUL bytes. It returns "" when account
// is nil or the prompt cache key is blank.
func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string {
	trimmedKey := strings.TrimSpace(promptCacheKey)
	if account == nil {
		return ""
	}
	if trimmedKey == "" {
		return ""
	}

	var apiKeyID int64
	if c != nil {
		apiKeyID = getAPIKeyIDFromContext(c)
	}

	return strconv.FormatInt(account.ID, 10) +
		"\x00" + strconv.FormatInt(apiKeyID, 10) +
		"\x00" + trimmedKey
}
// getOpenAICompatSessionResponseID returns the trimmed response ID bound to
// the session identified by (c, account, promptCacheKey), or "" when there is
// none to hand out.
//
// It returns "" when s is nil, the session key cannot be built, no binding is
// stored, the binding has expired, continuation is disabled for the session,
// or the stored response ID is blank. Entries of an unexpected type and
// expired entries are evicted. An entry with a blank response ID is evicted
// only when it carries no turn state either.
func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}
	raw, ok := s.openaiCompatSessionResponses.Load(key)
	if !ok {
		return ""
	}
	binding, ok := raw.(openAICompatSessionResponseBinding)
	if !ok {
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	if binding.ContinuationDisabled {
		return ""
	}
	responseID := strings.TrimSpace(binding.ResponseID)
	if responseID == "" {
		// A binding may legitimately hold only turn state (stored by
		// bindOpenAICompatSessionTurnState, or left behind by
		// deleteOpenAICompatSessionResponseID). Evicting it here would
		// silently discard that state, so drop only truly empty entries.
		if strings.TrimSpace(binding.TurnState) == "" {
			s.openaiCompatSessionResponses.Delete(key)
		}
		return ""
	}
	return responseID
}
// bindOpenAICompatSessionResponseID records responseID as the response bound
// to the session identified by (c, account, promptCacheKey), with an expiry of
// now plus openAIWSResponseStickyTTL.
//
// Nothing is stored when s is nil, the session key cannot be built, or
// responseID is blank. If the session already has continuation disabled, the
// existing binding is kept with a cleared response ID and a refreshed expiry.
// Otherwise any existing turn state is carried over to the new binding.
func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	id := strings.TrimSpace(responseID)
	if key == "" || id == "" {
		return
	}

	fresh := openAICompatSessionResponseBinding{
		ResponseID: id,
		ExpiresAt:  time.Now().Add(s.openAIWSResponseStickyTTL()),
	}

	// prev stays the zero value (not disabled, no turn state) when nothing
	// usable is stored, which makes the checks below no-ops in that case.
	var prev openAICompatSessionResponseBinding
	if raw, found := s.openaiCompatSessionResponses.Load(key); found {
		prev, _ = raw.(openAICompatSessionResponseBinding)
	}

	if prev.ContinuationDisabled {
		// Keep the session disabled; only refresh its lifetime.
		prev.ResponseID = ""
		prev.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
		s.openaiCompatSessionResponses.Store(key, prev)
		return
	}

	fresh.TurnState = prev.TurnState
	s.openaiCompatSessionResponses.Store(key, fresh)
}
// deleteOpenAICompatSessionResponseID clears the response ID bound to the
// session identified by (c, account, promptCacheKey).
//
// The whole entry is removed when it is of an unexpected type, or when it
// holds neither turn state nor a continuation-disabled flag. Otherwise the
// entry is kept with a blank response ID and an expiry refreshed to now plus
// openAIWSResponseStickyTTL. It does nothing when s is nil, the session key
// cannot be built, or no entry is stored.
func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}

	raw, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return
	}

	binding, valid := raw.(openAICompatSessionResponseBinding)
	worthKeeping := valid &&
		(binding.ContinuationDisabled || strings.TrimSpace(binding.TurnState) != "")
	if !worthKeeping {
		s.openaiCompatSessionResponses.Delete(key)
		return
	}

	binding.ResponseID = ""
	binding.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
	s.openaiCompatSessionResponses.Store(key, binding)
}
// disableOpenAICompatSessionContinuation marks the session identified by
// (c, account, promptCacheKey) as continuation-disabled.
//
// The stored binding is replaced by one with no response ID, the disabled flag
// set, an expiry of now plus openAIWSResponseStickyTTL, and the turn state of
// the previous binding if there was one. It does nothing when s is nil or the
// session key cannot be built.
func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}

	expiresAt := time.Now().Add(s.openAIWSResponseStickyTTL())

	var turnState string
	if raw, found := s.openaiCompatSessionResponses.Load(key); found {
		if prev, ok := raw.(openAICompatSessionResponseBinding); ok {
			turnState = prev.TurnState
		}
	}

	s.openaiCompatSessionResponses.Store(key, openAICompatSessionResponseBinding{
		TurnState:            turnState,
		ContinuationDisabled: true,
		ExpiresAt:            expiresAt,
	})
}
// isOpenAICompatSessionContinuationDisabled reports whether the session
// identified by (c, account, promptCacheKey) has a live binding with the
// continuation-disabled flag set.
//
// It returns false when s is nil, the session key cannot be built, or no
// entry is stored. Entries of an unexpected type and expired entries are
// evicted and reported as not disabled.
func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool {
	if s == nil {
		return false
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return false
	}

	raw, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return false
	}

	binding, valid := raw.(openAICompatSessionResponseBinding)
	expired := valid && !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt)
	if !valid || expired {
		s.openaiCompatSessionResponses.Delete(key)
		return false
	}
	return binding.ContinuationDisabled
}
// getOpenAICompatSessionTurnState returns the trimmed turn state stored for
// the session identified by (c, account, promptCacheKey), or "" when there is
// none.
//
// It returns "" when s is nil, the session key cannot be built, no entry is
// stored, the entry has expired, or the stored turn state is blank. Entries of
// an unexpected type and expired entries are evicted, as the other session
// accessors do.
func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}
	raw, ok := s.openaiCompatSessionResponses.Load(key)
	if !ok {
		return ""
	}
	binding, ok := raw.(openAICompatSessionResponseBinding)
	if !ok {
		// An entry of the wrong type can never be read; drop it rather than
		// leaving it in the store.
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	// Check expiry before looking at the turn state so that stale entries
	// without turn state are evicted too.
	if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	return strings.TrimSpace(binding.TurnState)
}
// bindOpenAICompatSessionTurnState stores turnState for the session identified
// by (c, account, promptCacheKey), with an expiry of now plus
// openAIWSResponseStickyTTL.
//
// The response ID and continuation-disabled flag of an existing binding are
// carried over. Nothing is stored when s is nil, the session key cannot be
// built, or turnState is blank.
func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	state := strings.TrimSpace(turnState)
	if key == "" || state == "" {
		return
	}

	expiresAt := time.Now().Add(s.openAIWSResponseStickyTTL())

	// prev stays the zero value when nothing usable is stored, so the copied
	// fields default to "" and false.
	var prev openAICompatSessionResponseBinding
	if raw, found := s.openaiCompatSessionResponses.Load(key); found {
		prev, _ = raw.(openAICompatSessionResponseBinding)
	}

	s.openaiCompatSessionResponses.Store(key, openAICompatSessionResponseBinding{
		ResponseID:           prev.ResponseID,
		TurnState:            state,
		ContinuationDisabled: prev.ContinuationDisabled,
		ExpiresAt:            expiresAt,
	})
}