fix(openai): avoid implicit image sticky sessions

This commit is contained in:
gaoren002
2026-04-26 17:05:19 +00:00
parent c056db740d
commit 615557ec20
3 changed files with 66 additions and 13 deletions

View File

@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
sessionHash := ""
if parsed.Multipart {
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
} else {
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
}
sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0

View File

@@ -1125,6 +1125,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
return sessionID
}
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
return ""
}
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
attachOpenAILegacySessionHashToGin(c, legacyHash)
return currentHash
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
@@ -1137,13 +1166,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}

View File

@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
t.Run("stateless image body stays unstuck", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("header overrides body", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
c.Request.Header.Set("session_id", "header-session")
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
})
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()