278 lines
8.1 KiB
Go
278 lines
8.1 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// openAICompatSessionResponseBinding is the per-session record that the
// accessors in this file store, by value, under the key produced by
// openAICompatSessionResponseKey.
type openAICompatSessionResponseBinding struct {
	// ResponseID is the response identifier bound to the session. Readers
	// treat a blank (whitespace-only) value as "nothing bound".
	ResponseID string
	// TurnState is an opaque string kept alongside the response ID and
	// carried over when the other fields are rewritten.
	TurnState string
	// ContinuationDisabled makes the response-ID getter return "" for the
	// session while the binding is alive.
	ContinuationDisabled bool
	// ExpiresAt is the instant after which readers treat the binding as
	// absent. The zero time means the binding is never considered expired.
	ExpiresAt time.Time
}
|
|
|
|
func openAICompatContinuationEnabled(account *Account, model string) bool {
|
|
if account == nil || account.Type != AccountTypeAPIKey {
|
|
return false
|
|
}
|
|
return shouldAutoInjectPromptCacheKeyForCompat(model)
|
|
}
|
|
|
|
func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesRequest) {
|
|
if req == nil || len(req.Input) == 0 {
|
|
return
|
|
}
|
|
|
|
var items []apicompat.ResponsesInputItem
|
|
if err := json.Unmarshal(req.Input, &items); err != nil || len(items) == 0 {
|
|
return
|
|
}
|
|
|
|
start := len(items) - 1
|
|
for start > 0 && items[start].Type == "function_call_output" {
|
|
start--
|
|
}
|
|
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
|
|
if len(trimmed) == len(items) {
|
|
return
|
|
}
|
|
if input, err := json.Marshal(trimmed); err == nil {
|
|
req.Input = input
|
|
}
|
|
}
|
|
|
|
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
|
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
|
|
return false
|
|
}
|
|
check := func(s string) bool {
|
|
lower := strings.ToLower(strings.TrimSpace(s))
|
|
return strings.Contains(lower, "previous_response_not_found") ||
|
|
(strings.Contains(lower, "previous response") && strings.Contains(lower, "not found")) ||
|
|
(strings.Contains(lower, "unsupported parameter") && strings.Contains(lower, "previous_response_id"))
|
|
}
|
|
if check(upstreamMsg) || check(string(upstreamBody)) {
|
|
return true
|
|
}
|
|
return check(gjson.GetBytes(upstreamBody, "error.code").String()) ||
|
|
check(gjson.GetBytes(upstreamBody, "error.message").String())
|
|
}
|
|
|
|
func isOpenAICompatPreviousResponseUnsupported(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
|
if statusCode != http.StatusBadRequest {
|
|
return false
|
|
}
|
|
check := func(s string) bool {
|
|
lower := strings.ToLower(strings.TrimSpace(s))
|
|
if !strings.Contains(lower, "previous_response_id") {
|
|
return false
|
|
}
|
|
return strings.Contains(lower, "unsupported parameter") ||
|
|
strings.Contains(lower, "only supported on responses websocket") ||
|
|
strings.Contains(lower, "not supported")
|
|
}
|
|
if check(upstreamMsg) || check(string(upstreamBody)) {
|
|
return true
|
|
}
|
|
return check(gjson.GetBytes(upstreamBody, "error.code").String()) ||
|
|
check(gjson.GetBytes(upstreamBody, "error.message").String())
|
|
}
|
|
|
|
func openAICompatSessionResponseKey(c *gin.Context, account *Account, promptCacheKey string) string {
|
|
key := strings.TrimSpace(promptCacheKey)
|
|
if account == nil || key == "" {
|
|
return ""
|
|
}
|
|
apiKeyID := int64(0)
|
|
if c != nil {
|
|
apiKeyID = getAPIKeyIDFromContext(c)
|
|
}
|
|
return strings.Join([]string{
|
|
strconv.FormatInt(account.ID, 10),
|
|
strconv.FormatInt(apiKeyID, 10),
|
|
key,
|
|
}, "\x00")
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) getOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
|
|
if s == nil {
|
|
return ""
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
if key == "" {
|
|
return ""
|
|
}
|
|
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
binding, ok := raw.(openAICompatSessionResponseBinding)
|
|
if !ok {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return ""
|
|
}
|
|
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return ""
|
|
}
|
|
if binding.ContinuationDisabled {
|
|
return ""
|
|
}
|
|
if strings.TrimSpace(binding.ResponseID) == "" {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(binding.ResponseID)
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) bindOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey, responseID string) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
id := strings.TrimSpace(responseID)
|
|
if key == "" || id == "" {
|
|
return
|
|
}
|
|
binding := openAICompatSessionResponseBinding{
|
|
ResponseID: id,
|
|
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
|
}
|
|
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
|
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
|
if existing.ContinuationDisabled {
|
|
existing.ResponseID = ""
|
|
existing.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
|
|
s.openaiCompatSessionResponses.Store(key, existing)
|
|
return
|
|
}
|
|
binding.TurnState = existing.TurnState
|
|
}
|
|
}
|
|
s.openaiCompatSessionResponses.Store(key, binding)
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) deleteOpenAICompatSessionResponseID(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
if key == "" {
|
|
return
|
|
}
|
|
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
|
if !ok {
|
|
return
|
|
}
|
|
binding, ok := raw.(openAICompatSessionResponseBinding)
|
|
if !ok {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return
|
|
}
|
|
binding.ResponseID = ""
|
|
if strings.TrimSpace(binding.TurnState) == "" && !binding.ContinuationDisabled {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return
|
|
}
|
|
binding.ExpiresAt = time.Now().Add(s.openAIWSResponseStickyTTL())
|
|
s.openaiCompatSessionResponses.Store(key, binding)
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) disableOpenAICompatSessionContinuation(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
if key == "" {
|
|
return
|
|
}
|
|
binding := openAICompatSessionResponseBinding{
|
|
ContinuationDisabled: true,
|
|
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
|
}
|
|
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
|
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
|
binding.TurnState = existing.TurnState
|
|
}
|
|
}
|
|
s.openaiCompatSessionResponses.Store(key, binding)
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) isOpenAICompatSessionContinuationDisabled(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
if key == "" {
|
|
return false
|
|
}
|
|
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
|
if !ok {
|
|
return false
|
|
}
|
|
binding, ok := raw.(openAICompatSessionResponseBinding)
|
|
if !ok {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return false
|
|
}
|
|
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return false
|
|
}
|
|
return binding.ContinuationDisabled
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) getOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey string) string {
|
|
if s == nil {
|
|
return ""
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
if key == "" {
|
|
return ""
|
|
}
|
|
raw, ok := s.openaiCompatSessionResponses.Load(key)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
binding, ok := raw.(openAICompatSessionResponseBinding)
|
|
if !ok || strings.TrimSpace(binding.TurnState) == "" {
|
|
return ""
|
|
}
|
|
if !binding.ExpiresAt.IsZero() && time.Now().After(binding.ExpiresAt) {
|
|
s.openaiCompatSessionResponses.Delete(key)
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(binding.TurnState)
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) bindOpenAICompatSessionTurnState(_ context.Context, c *gin.Context, account *Account, promptCacheKey, turnState string) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := openAICompatSessionResponseKey(c, account, promptCacheKey)
|
|
state := strings.TrimSpace(turnState)
|
|
if key == "" || state == "" {
|
|
return
|
|
}
|
|
binding := openAICompatSessionResponseBinding{
|
|
TurnState: state,
|
|
ExpiresAt: time.Now().Add(s.openAIWSResponseStickyTTL()),
|
|
}
|
|
if raw, ok := s.openaiCompatSessionResponses.Load(key); ok {
|
|
if existing, ok := raw.(openAICompatSessionResponseBinding); ok {
|
|
binding.ResponseID = existing.ResponseID
|
|
binding.ContinuationDisabled = existing.ContinuationDisabled
|
|
}
|
|
}
|
|
s.openaiCompatSessionResponses.Store(key, binding)
|
|
}
|