Merge pull request #2006 from gaoren002/pr/openai-images-explicit-session
fix(openai): avoid implicit image sticky sessions
This commit is contained in:
@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := ""
|
||||
if parsed.Multipart {
|
||||
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
|
||||
} else {
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
|
||||
@@ -1125,6 +1125,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
|
||||
return sessionID
|
||||
}
|
||||
|
||||
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
|
||||
// client session signals. It intentionally skips content-derived fallback and is
|
||||
// used by stateless endpoints such as /v1/images.
|
||||
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
|
||||
attachOpenAILegacySessionHashToGin(c, legacyHash)
|
||||
return currentHash
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
@@ -1137,13 +1166,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
sessionID := explicitOpenAISessionID(c, body)
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = deriveOpenAIContentSessionSeed(body)
|
||||
}
|
||||
|
||||
@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
svc := &OpenAIGatewayService{}
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
|
||||
|
||||
t.Run("stateless image body stays unstuck", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
|
||||
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("header overrides body", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
|
||||
c.Request.Header.Set("session_id", "header-session")
|
||||
|
||||
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
|
||||
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
Reference in New Issue
Block a user