215 lines
5.8 KiB
Go
215 lines
5.8 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/cespare/xxhash/v2"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// openAILegacySessionHashContextKey is an unexported, empty key type used to
// store the legacy session hash in a context.Context. Being a distinct type,
// it cannot collide with context keys declared by other packages.
type openAILegacySessionHashContextKey struct{}

// openAILegacySessionHashKey is the single context key under which the legacy
// session hash is stored and looked up.
var openAILegacySessionHashKey openAILegacySessionHashContextKey
|
|
|
|
// Process-wide counters for the legacy session-hash compatibility paths.
var (
	// openAIStickyLegacyReadFallbackTotal is incremented each time a sticky
	// session lookup falls through to the legacy cache key.
	openAIStickyLegacyReadFallbackTotal atomic.Int64
	// openAIStickyLegacyReadFallbackHit is incremented each time such a
	// legacy lookup yields a positive account ID.
	openAIStickyLegacyReadFallbackHit atomic.Int64
	// openAIStickyLegacyDualWriteTotal is incremented after each successful
	// mirrored write to the legacy cache key.
	openAIStickyLegacyDualWriteTotal atomic.Int64
)

// openAIStickyCompatStats returns the current values of the three
// compatibility counters. Each counter is loaded independently, so the result
// is not a single atomic snapshot across all three.
func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
	legacyReadFallbackTotal = openAIStickyLegacyReadFallbackTotal.Load()
	legacyReadFallbackHit = openAIStickyLegacyReadFallbackHit.Load()
	legacyDualWriteTotal = openAIStickyLegacyDualWriteTotal.Load()
	return legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal
}
|
|
|
|
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
|
|
normalized := strings.TrimSpace(sessionID)
|
|
if normalized == "" {
|
|
return "", ""
|
|
}
|
|
|
|
currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
|
|
sum := sha256.Sum256([]byte(normalized))
|
|
legacyHash = hex.EncodeToString(sum[:])
|
|
return currentHash, legacyHash
|
|
}
|
|
|
|
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
|
|
if ctx == nil {
|
|
return nil
|
|
}
|
|
trimmed := strings.TrimSpace(legacyHash)
|
|
if trimmed == "" {
|
|
return ctx
|
|
}
|
|
return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
|
|
}
|
|
|
|
func openAILegacySessionHashFromContext(ctx context.Context) string {
|
|
if ctx == nil {
|
|
return ""
|
|
}
|
|
value, _ := ctx.Value(openAILegacySessionHashKey).(string)
|
|
return strings.TrimSpace(value)
|
|
}
|
|
|
|
func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
|
|
if c == nil || c.Request == nil {
|
|
return
|
|
}
|
|
c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
|
|
if s == nil || s.cfg == nil {
|
|
return true
|
|
}
|
|
return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
|
|
if s == nil || s.cfg == nil {
|
|
return true
|
|
}
|
|
return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
|
|
normalized := strings.TrimSpace(sessionHash)
|
|
if normalized == "" {
|
|
return ""
|
|
}
|
|
return "openai:" + normalized
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
|
|
legacyHash := openAILegacySessionHashFromContext(ctx)
|
|
if legacyHash == "" {
|
|
return ""
|
|
}
|
|
legacyKey := "openai:" + legacyHash
|
|
if legacyKey == s.openAISessionCacheKey(sessionHash) {
|
|
return ""
|
|
}
|
|
return legacyKey
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
|
|
legacyTTL := ttl
|
|
if legacyTTL <= 0 {
|
|
legacyTTL = openaiStickySessionTTL
|
|
}
|
|
if legacyTTL > 10*time.Minute {
|
|
return 10 * time.Minute
|
|
}
|
|
return legacyTTL
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
|
|
if s == nil || s.cache == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
primaryKey := s.openAISessionCacheKey(sessionHash)
|
|
if primaryKey == "" {
|
|
return 0, nil
|
|
}
|
|
|
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
|
|
if err == nil && accountID > 0 {
|
|
return accountID, nil
|
|
}
|
|
if !s.openAISessionHashReadOldFallbackEnabled() {
|
|
return accountID, err
|
|
}
|
|
|
|
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
|
if legacyKey == "" {
|
|
return accountID, err
|
|
}
|
|
|
|
openAIStickyLegacyReadFallbackTotal.Add(1)
|
|
legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
|
|
if legacyErr == nil && legacyAccountID > 0 {
|
|
openAIStickyLegacyReadFallbackHit.Add(1)
|
|
return legacyAccountID, nil
|
|
}
|
|
return accountID, err
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
|
if s == nil || s.cache == nil || accountID <= 0 {
|
|
return nil
|
|
}
|
|
primaryKey := s.openAISessionCacheKey(sessionHash)
|
|
if primaryKey == "" {
|
|
return nil
|
|
}
|
|
|
|
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !s.openAISessionHashDualWriteOldEnabled() {
|
|
return nil
|
|
}
|
|
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
|
if legacyKey == "" {
|
|
return nil
|
|
}
|
|
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
|
|
return err
|
|
}
|
|
openAIStickyLegacyDualWriteTotal.Add(1)
|
|
return nil
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
|
|
if s == nil || s.cache == nil {
|
|
return nil
|
|
}
|
|
primaryKey := s.openAISessionCacheKey(sessionHash)
|
|
if primaryKey == "" {
|
|
return nil
|
|
}
|
|
|
|
err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
|
|
if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
|
|
return err
|
|
}
|
|
|
|
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
|
if legacyKey != "" {
|
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
|
|
if s == nil || s.cache == nil {
|
|
return nil
|
|
}
|
|
primaryKey := s.openAISessionCacheKey(sessionHash)
|
|
if primaryKey == "" {
|
|
return nil
|
|
}
|
|
|
|
err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
|
|
if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
|
|
return err
|
|
}
|
|
|
|
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
|
if legacyKey != "" {
|
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
|
|
}
|
|
return err
|
|
}
|