Files
sub2api-ht/backend/internal/service/openai_messages_continuation.go

278 lines
8.1 KiB
Go

package service
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// openAICompatSessionResponseBinding is the value kept in
// OpenAIGatewayService.openaiCompatSessionResponses under a key produced by
// openAICompatSessionResponseKey.
type openAICompatSessionResponseBinding struct {
// ResponseID is the trimmed response ID recorded by
// bindOpenAICompatSessionResponseID. It is left empty while continuation is
// disabled and is cleared by deleteOpenAICompatSessionResponseID.
ResponseID string
// TurnState is the trimmed value recorded by bindOpenAICompatSessionTurnState.
TurnState string
// ContinuationDisabled, when true, makes getOpenAICompatSessionResponseID
// return "" and stops bindOpenAICompatSessionResponseID from recording IDs.
ContinuationDisabled bool
// ExpiresAt is the instant after which the readers treat the binding as
// absent and delete it. The zero value never expires.
ExpiresAt time.Time
}
// openAICompatContinuationEnabled reports whether account is a non-nil account
// of type AccountTypeAPIKey and shouldAutoInjectPromptCacheKeyForCompat accepts
// model. Any other account yields false without consulting the model.
func openAICompatContinuationEnabled(account *Account, model string) bool {
	if account != nil && account.Type == AccountTypeAPIKey {
		return shouldAutoInjectPromptCacheKeyForCompat(model)
	}
	return false
}
// trimAnthropicCompatResponsesInputToLatestTurn shrinks req.Input in place to
// the tail of its item list: the run of trailing "function_call_output" items
// together with the single item that immediately precedes that run.
//
// The request is left untouched when it is nil, when Input is empty or does
// not decode as a list of items, when the tail already covers the whole list,
// or when the tail cannot be re-encoded.
func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) {
	if req == nil || len(req.Input) == 0 {
		return
	}

	var items []apicompat.ResponsesInputItem
	if json.Unmarshal(req.Input, &items) != nil || len(items) == 0 {
		return
	}

	// Walk backwards over the trailing tool outputs; first stops on the item
	// just before them (or on index 0 if everything is a tool output).
	first := len(items) - 1
	for ; first > 0; first-- {
		if items[first].Type != "function_call_output" {
			break
		}
	}
	if first == 0 {
		// Nothing would be dropped.
		return
	}

	encoded, err := json.Marshal(items[first:])
	if err != nil {
		return
	}
	req.Input = encoded
}
// isOpenAICompatPreviousResponseNotFound reports whether an upstream 400 or 404
// response matches one of the recognised "previous response" failure texts.
//
// The match is case-insensitive and is tried, in order, against upstreamMsg,
// the raw upstreamBody, and the body's "error.code" and "error.message" JSON
// fields. A text matches when it contains "previous_response_not_found", or
// both "previous response" and "not found", or both "unsupported parameter"
// and "previous_response_id".
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	switch statusCode {
	case http.StatusBadRequest, http.StatusNotFound:
		// Only these statuses are inspected further.
	default:
		return false
	}

	matches := func(text string) bool {
		t := strings.ToLower(strings.TrimSpace(text))
		switch {
		case strings.Contains(t, "previous_response_not_found"):
			return true
		case strings.Contains(t, "previous response") && strings.Contains(t, "not found"):
			return true
		default:
			return strings.Contains(t, "unsupported parameter") && strings.Contains(t, "previous_response_id")
		}
	}

	if matches(upstreamMsg) || matches(string(upstreamBody)) {
		return true
	}
	// The decoded JSON fields are checked separately from the raw body.
	for _, path := range []string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// isOpenAICompatPreviousResponseUnsupported reports whether an upstream 400
// response says that previous_response_id is not accepted.
//
// The match is case-insensitive and is tried, in order, against upstreamMsg,
// the raw upstreamBody, and the body's "error.code" and "error.message" JSON
// fields. A text matches when it mentions "previous_response_id" together with
// "unsupported parameter", "only supported on responses websocket", or
// "not supported".
func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
	if statusCode != http.StatusBadRequest {
		return false
	}

	reasons := []string{
		"unsupported parameter",
		"only supported on responses websocket",
		"not supported",
	}
	matches := func(text string) bool {
		t := strings.ToLower(strings.TrimSpace(text))
		if !strings.Contains(t, "previous_response_id") {
			return false
		}
		for _, reason := range reasons {
			if strings.Contains(t, reason) {
				return true
			}
		}
		return false
	}

	if matches(upstreamMsg) || matches(string(upstreamBody)) {
		return true
	}
	// The decoded JSON fields are checked separately from the raw body.
	for _, path := range []string{"error.code", "error.message"} {
		if matches(gjson.GetBytes(upstreamBody, path).String()) {
			return true
		}
	}
	return false
}
// openAICompatSessionResponseKey builds the store key for a session binding:
// the account ID, the API key ID read from c, and the trimmed prompt cache
// key, joined by NUL bytes.
//
// It returns "" when account is nil or the trimmed promptCacheKey is empty.
// When c is nil the API key ID component is 0.
func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string {
	cacheKey := strings.TrimSpace(promptCacheKey)
	if cacheKey == "" || account == nil {
		return ""
	}

	var apiKeyID int64
	if c != nil {
		apiKeyID = getAPIKeyIDFromContext(c)
	}

	const sep = "\x00"
	return strconv.FormatInt(account.ID, 10) + sep + strconv.FormatInt(apiKeyID, 10) + sep + cacheKey
}
// getOpenAICompatSessionResponseID returns the trimmed response ID bound to the
// session identified by (account, API key from c, promptCacheKey), or "" when
// there is none to use.
//
// The stored entry is deleted when it is not a binding, when it has expired,
// or when it is an enabled binding with a blank response ID. A binding with
// ContinuationDisabled set yields "" but is kept.
func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}

	stored, found := s.openaiCompatSessionResponses.Load(key)
	if !found {
		return ""
	}

	binding, isBinding := stored.(openAICompatSessionResponseBinding)
	expired := isBinding && !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt)
	switch {
	case !isBinding, expired:
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	case binding.ContinuationDisabled:
		// Keep the entry so the disabled flag survives.
		return ""
	}

	id := strings.TrimSpace(binding.ResponseID)
	if id == "" {
		s.openaiCompatSessionResponses.Delete(key)
	}
	return id
}
// bindOpenAICompatSessionResponseID records responseID (trimmed) for the
// session identified by (account, API key from c, promptCacheKey) and renews
// the binding's expiry using openAIWSResponseStickyTTL.
//
// It is a no-op when s is nil, when no key can be built, or when responseID is
// blank. A live existing binding contributes its TurnState; if that binding
// has ContinuationDisabled set, the flag is kept and no response ID is stored.
//
// An existing binding whose ExpiresAt has passed is ignored, matching
// getOpenAICompatSessionResponseID and
// isOpenAICompatSessionContinuationDisabled, which treat such a binding as
// absent. Without this check a stale disabled flag or turn state that had not
// yet been purged by a reader would be revived with a fresh TTL.
func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	id := strings.TrimSpace(responseID)
	if key == "" || id == "" {
		return
	}

	now := time.Now()
	binding := openAICompatSessionResponseBinding{
		ResponseID: id,
		ExpiresAt:  now.Add(s.openAIWSResponseStickyTTL()),
	}
	if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
		existing, isBinding := raw.(openAICompatSessionResponseBinding)
		live := isBinding && (existing.ExpiresAt.IsZero() || !now.After(existing.ExpiresAt))
		if live {
			binding.TurnState = existing.TurnState
			if existing.ContinuationDisabled {
				// Continuation stays off: keep the flag, store no ID.
				binding.ContinuationDisabled = true
				binding.ResponseID = ""
			}
		}
	}
	s.openaiCompatSessionResponses.Store(key, binding)
}
// deleteOpenAICompatSessionResponseID forgets the response ID bound to the
// session identified by (account, API key from c, promptCacheKey).
//
// If the binding still carries a non-blank TurnState or has
// ContinuationDisabled set, it is kept with an empty ResponseID and a renewed
// expiry; otherwise the whole entry is removed. Entries that are not bindings
// are removed as well.
//
// A binding whose ExpiresAt has passed is removed outright instead of being
// re-stored. The readers in this file treat an expired binding as absent, so
// renewing its TTL here would revive stale turn state or a stale disabled flag.
func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}
	raw, ok := s.openaiCompatSessionResponses.Load(key)
	if !ok {
		return
	}

	now := time.Now()
	binding, isBinding := raw.(openAICompatSessionResponseBinding)
	live := isBinding && (binding.ExpiresAt.IsZero() || !now.After(binding.ExpiresAt))
	worthKeeping := live && (strings.TrimSpace(binding.TurnState) != "" || binding.ContinuationDisabled)
	if !worthKeeping {
		s.openaiCompatSessionResponses.Delete(key)
		return
	}

	binding.ResponseID = ""
	binding.ExpiresAt = now.Add(s.openAIWSResponseStickyTTL())
	s.openaiCompatSessionResponses.Store(key, binding)
}
// disableOpenAICompatSessionContinuation marks continuation as disabled for
// the session identified by (account, API key from c, promptCacheKey). The
// stored binding gets ContinuationDisabled set, an empty ResponseID, and an
// expiry of now plus openAIWSResponseStickyTTL.
//
// The TurnState of a live existing binding is preserved. A binding whose
// ExpiresAt has passed is ignored, because the readers in this file treat it
// as absent and its turn state should not be revived with a fresh TTL.
func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return
	}

	now := time.Now()
	binding := openAICompatSessionResponseBinding{
		ContinuationDisabled: true,
		ExpiresAt:            now.Add(s.openAIWSResponseStickyTTL()),
	}
	if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
		existing, isBinding := raw.(openAICompatSessionResponseBinding)
		if isBinding && (existing.ExpiresAt.IsZero() || !now.After(existing.ExpiresAt)) {
			binding.TurnState = existing.TurnState
		}
	}
	s.openaiCompatSessionResponses.Store(key, binding)
}
// isOpenAICompatSessionContinuationDisabled reports whether the session
// identified by (account, API key from c, promptCacheKey) has a live binding
// with ContinuationDisabled set.
//
// It returns false when s is nil, when no key can be built, or when nothing is
// stored. An entry that is not a binding, or whose ExpiresAt has passed, is
// deleted and reported as false.
func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool {
	if s == nil {
		return false
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return false
	}

	if stored, found := s.openaiCompatSessionResponses.Load(key); found {
		binding, isBinding := stored.(openAICompatSessionResponseBinding)
		if isBinding && (binding.ExpiresAt.IsZero() || !time.Now().After(binding.ExpiresAt)) {
			return binding.ContinuationDisabled
		}
		// Malformed or expired: purge it.
		s.openaiCompatSessionResponses.Delete(key)
	}
	return false
}
// getOpenAICompatSessionTurnState returns the trimmed TurnState bound to the
// session identified by (account, API key from c, promptCacheKey), or "" when
// there is none.
//
// Like getOpenAICompatSessionResponseID and
// isOpenAICompatSessionContinuationDisabled, it deletes an entry that is not a
// binding or whose ExpiresAt has passed. A live binding with a blank TurnState
// is kept, since it may still carry a response ID or the disabled flag.
func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
	if s == nil {
		return ""
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	if key == "" {
		return ""
	}
	raw, ok := s.openaiCompatSessionResponses.Load(key)
	if !ok {
		return ""
	}
	binding, ok := raw.(openAICompatSessionResponseBinding)
	if !ok {
		// Purge values of an unexpected type, as the sibling readers do.
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	// Check expiry before looking at TurnState so that expired bindings are
	// purged even when they carry no turn state.
	if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
		s.openaiCompatSessionResponses.Delete(key)
		return ""
	}
	return strings.TrimSpace(binding.TurnState)
}
// bindOpenAICompatSessionTurnState records turnState (trimmed) for the session
// identified by (account, API key from c, promptCacheKey) and renews the
// binding's expiry using openAIWSResponseStickyTTL.
//
// It is a no-op when s is nil, when no key can be built, or when turnState is
// blank. The ResponseID and ContinuationDisabled flag of a live existing
// binding are preserved. A binding whose ExpiresAt has passed is ignored,
// because the readers in this file treat it as absent; copying its fields
// would revive a stale response ID or disabled flag with a fresh TTL.
func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) {
	if s == nil {
		return
	}
	key := openAICompatSessionResponseKey(c, account, promptCacheKey)
	state := strings.TrimSpace(turnState)
	if key == "" || state == "" {
		return
	}

	now := time.Now()
	binding := openAICompatSessionResponseBinding{
		TurnState: state,
		ExpiresAt: now.Add(s.openAIWSResponseStickyTTL()),
	}
	if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
		existing, isBinding := raw.(openAICompatSessionResponseBinding)
		if isBinding && (existing.ExpiresAt.IsZero() || !now.After(existing.ExpiresAt)) {
			binding.ResponseID = existing.ResponseID
			binding.ContinuationDisabled = existing.ContinuationDisabled
		}
	}
	s.openaiCompatSessionResponses.Store(key, binding)
}