441 lines
12 KiB
Go
441 lines
12 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
)
|
||
|
||
// Tunables for the WSv2 state store.
const (
	// openAIWSResponseAccountCachePrefix namespaces response -> account keys in GatewayCache.
	openAIWSResponseAccountCachePrefix = "openai:response:"

	// openAIWSStateStoreRedisTimeout bounds each GatewayCache call issued by the store.
	openAIWSStateStoreRedisTimeout = 3 * time.Second

	// openAIWSStateStoreCleanupInterval is the minimum spacing between cleanup sweeps.
	openAIWSStateStoreCleanupInterval = 60 * time.Second

	// openAIWSStateStoreCleanupMaxPerMap caps how many entries one sweep examines per map.
	openAIWSStateStoreCleanupMaxPerMap = 1 << 9

	// openAIWSStateStoreMaxEntriesPerMap is the hard size cap enforced on each local map.
	openAIWSStateStoreMaxEntriesPerMap = 1 << 16
)
|
||
|
||
// Local binding records. Each pairs a stored value with the instant after
// which the store treats that value as stale.
type (
	// openAIWSAccountBinding is the local copy of a response_id -> account_id binding.
	openAIWSAccountBinding struct {
		accountID int64
		expiresAt time.Time
	}

	// openAIWSConnBinding maps a response_id to the connection that produced it.
	openAIWSConnBinding struct {
		connID    string
		expiresAt time.Time
	}

	// openAIWSTurnStateBinding holds the turn state remembered for a session.
	openAIWSTurnStateBinding struct {
		turnState string
		expiresAt time.Time
	}

	// openAIWSSessionConnBinding maps a session to its connection.
	openAIWSSessionConnBinding struct {
		connID    string
		expiresAt time.Time
	}
)
|
||
|
||
// OpenAIWSStateStore manages WSv2 sticky-routing state.
//
//   - response_id -> account_id: used to route follow-up (chained) requests.
//     This mapping prefers GatewayCache (Redis) and also keeps a local hot copy.
//   - response_id -> conn_id: used to reuse context within a connection; it is
//     only meaningful inside this process.
//   - (groupID, sessionHash) -> turn state and (groupID, sessionHash) -> conn_id:
//     per-session bindings.
type OpenAIWSStateStore interface {
	// Response -> account bindings (mirrored to the cache when available).
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	// Response -> connection bindings (process-local).
	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)

	// Session -> turn state bindings.
	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)

	// Session -> connection bindings.
	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}
|
||
|
||
type defaultOpenAIWSStateStore struct {
|
||
cache GatewayCache
|
||
|
||
responseToAccountMu sync.RWMutex
|
||
responseToAccount map[string]openAIWSAccountBinding
|
||
responseToConnMu sync.RWMutex
|
||
responseToConn map[string]openAIWSConnBinding
|
||
sessionToTurnStateMu sync.RWMutex
|
||
sessionToTurnState map[string]openAIWSTurnStateBinding
|
||
sessionToConnMu sync.RWMutex
|
||
sessionToConn map[string]openAIWSSessionConnBinding
|
||
|
||
lastCleanupUnixNano atomic.Int64
|
||
}
|
||
|
||
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
|
||
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
|
||
store := &defaultOpenAIWSStateStore{
|
||
cache: cache,
|
||
responseToAccount: make(map[string]openAIWSAccountBinding, 256),
|
||
responseToConn: make(map[string]openAIWSConnBinding, 256),
|
||
sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
|
||
sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
|
||
}
|
||
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
|
||
return store
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
if id == "" || accountID <= 0 {
|
||
return nil
|
||
}
|
||
ttl = normalizeOpenAIWSTTL(ttl)
|
||
s.maybeCleanup()
|
||
|
||
expiresAt := time.Now().Add(ttl)
|
||
s.responseToAccountMu.Lock()
|
||
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
|
||
s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
|
||
s.responseToAccountMu.Unlock()
|
||
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
cacheKey := openAIWSResponseAccountCacheKey(id)
|
||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||
defer cancel()
|
||
return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
if id == "" {
|
||
return 0, nil
|
||
}
|
||
s.maybeCleanup()
|
||
|
||
now := time.Now()
|
||
s.responseToAccountMu.RLock()
|
||
if binding, ok := s.responseToAccount[id]; ok {
|
||
if now.Before(binding.expiresAt) {
|
||
accountID := binding.accountID
|
||
s.responseToAccountMu.RUnlock()
|
||
return accountID, nil
|
||
}
|
||
}
|
||
s.responseToAccountMu.RUnlock()
|
||
|
||
if s.cache == nil {
|
||
return 0, nil
|
||
}
|
||
|
||
cacheKey := openAIWSResponseAccountCacheKey(id)
|
||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||
defer cancel()
|
||
accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
|
||
if err != nil || accountID <= 0 {
|
||
// 缓存读取失败不阻断主流程,按未命中降级。
|
||
return 0, nil
|
||
}
|
||
return accountID, nil
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
if id == "" {
|
||
return nil
|
||
}
|
||
s.responseToAccountMu.Lock()
|
||
delete(s.responseToAccount, id)
|
||
s.responseToAccountMu.Unlock()
|
||
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||
defer cancel()
|
||
return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
conn := strings.TrimSpace(connID)
|
||
if id == "" || conn == "" {
|
||
return
|
||
}
|
||
ttl = normalizeOpenAIWSTTL(ttl)
|
||
s.maybeCleanup()
|
||
|
||
s.responseToConnMu.Lock()
|
||
ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
|
||
s.responseToConn[id] = openAIWSConnBinding{
|
||
connID: conn,
|
||
expiresAt: time.Now().Add(ttl),
|
||
}
|
||
s.responseToConnMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
if id == "" {
|
||
return "", false
|
||
}
|
||
s.maybeCleanup()
|
||
|
||
now := time.Now()
|
||
s.responseToConnMu.RLock()
|
||
binding, ok := s.responseToConn[id]
|
||
s.responseToConnMu.RUnlock()
|
||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||
return "", false
|
||
}
|
||
return binding.connID, true
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
|
||
id := normalizeOpenAIWSResponseID(responseID)
|
||
if id == "" {
|
||
return
|
||
}
|
||
s.responseToConnMu.Lock()
|
||
delete(s.responseToConn, id)
|
||
s.responseToConnMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
state := strings.TrimSpace(turnState)
|
||
if key == "" || state == "" {
|
||
return
|
||
}
|
||
ttl = normalizeOpenAIWSTTL(ttl)
|
||
s.maybeCleanup()
|
||
|
||
s.sessionToTurnStateMu.Lock()
|
||
ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
|
||
s.sessionToTurnState[key] = openAIWSTurnStateBinding{
|
||
turnState: state,
|
||
expiresAt: time.Now().Add(ttl),
|
||
}
|
||
s.sessionToTurnStateMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
if key == "" {
|
||
return "", false
|
||
}
|
||
s.maybeCleanup()
|
||
|
||
now := time.Now()
|
||
s.sessionToTurnStateMu.RLock()
|
||
binding, ok := s.sessionToTurnState[key]
|
||
s.sessionToTurnStateMu.RUnlock()
|
||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
|
||
return "", false
|
||
}
|
||
return binding.turnState, true
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
if key == "" {
|
||
return
|
||
}
|
||
s.sessionToTurnStateMu.Lock()
|
||
delete(s.sessionToTurnState, key)
|
||
s.sessionToTurnStateMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
conn := strings.TrimSpace(connID)
|
||
if key == "" || conn == "" {
|
||
return
|
||
}
|
||
ttl = normalizeOpenAIWSTTL(ttl)
|
||
s.maybeCleanup()
|
||
|
||
s.sessionToConnMu.Lock()
|
||
ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
|
||
s.sessionToConn[key] = openAIWSSessionConnBinding{
|
||
connID: conn,
|
||
expiresAt: time.Now().Add(ttl),
|
||
}
|
||
s.sessionToConnMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
if key == "" {
|
||
return "", false
|
||
}
|
||
s.maybeCleanup()
|
||
|
||
now := time.Now()
|
||
s.sessionToConnMu.RLock()
|
||
binding, ok := s.sessionToConn[key]
|
||
s.sessionToConnMu.RUnlock()
|
||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||
return "", false
|
||
}
|
||
return binding.connID, true
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
|
||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||
if key == "" {
|
||
return
|
||
}
|
||
s.sessionToConnMu.Lock()
|
||
delete(s.sessionToConn, key)
|
||
s.sessionToConnMu.Unlock()
|
||
}
|
||
|
||
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
|
||
if s == nil {
|
||
return
|
||
}
|
||
now := time.Now()
|
||
last := time.Unix(0, s.lastCleanupUnixNano.Load())
|
||
if now.Sub(last) < openAIWSStateStoreCleanupInterval {
|
||
return
|
||
}
|
||
if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
|
||
return
|
||
}
|
||
|
||
// 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。
|
||
s.responseToAccountMu.Lock()
|
||
cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
|
||
s.responseToAccountMu.Unlock()
|
||
|
||
s.responseToConnMu.Lock()
|
||
cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
||
s.responseToConnMu.Unlock()
|
||
|
||
s.sessionToTurnStateMu.Lock()
|
||
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
|
||
s.sessionToTurnStateMu.Unlock()
|
||
|
||
s.sessionToConnMu.Lock()
|
||
cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
||
s.sessionToConnMu.Unlock()
|
||
}
|
||
|
||
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
|
||
if len(bindings) == 0 || maxScan <= 0 {
|
||
return
|
||
}
|
||
scanned := 0
|
||
for key, binding := range bindings {
|
||
if now.After(binding.expiresAt) {
|
||
delete(bindings, key)
|
||
}
|
||
scanned++
|
||
if scanned >= maxScan {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
|
||
if len(bindings) == 0 || maxScan <= 0 {
|
||
return
|
||
}
|
||
scanned := 0
|
||
for key, binding := range bindings {
|
||
if now.After(binding.expiresAt) {
|
||
delete(bindings, key)
|
||
}
|
||
scanned++
|
||
if scanned >= maxScan {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
|
||
if len(bindings) == 0 || maxScan <= 0 {
|
||
return
|
||
}
|
||
scanned := 0
|
||
for key, binding := range bindings {
|
||
if now.After(binding.expiresAt) {
|
||
delete(bindings, key)
|
||
}
|
||
scanned++
|
||
if scanned >= maxScan {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
|
||
if len(bindings) == 0 || maxScan <= 0 {
|
||
return
|
||
}
|
||
scanned := 0
|
||
for key, binding := range bindings {
|
||
if now.After(binding.expiresAt) {
|
||
delete(bindings, key)
|
||
}
|
||
scanned++
|
||
if scanned >= maxScan {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// ensureBindingCapacity makes room for incomingKey once bindings has reached
// maxEntries. Overwriting an existing key does not grow the map, so nothing
// is evicted in that case. Otherwise a single arbitrary entry (whichever Go's
// map iteration yields first) is dropped to keep memory bounded. A
// non-positive maxEntries disables the cap. Every caller in this file invokes
// it while holding the map's write lock.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, replacing := bindings[incomingKey]; replacing {
		return
	}
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
|
||
|
||
// normalizeOpenAIWSResponseID canonicalizes a response ID by stripping leading
// and trailing whitespace; an all-whitespace ID becomes "".
func normalizeOpenAIWSResponseID(responseID string) string {
	normalized := strings.TrimSpace(responseID)
	return normalized
}
|
||
|
||
func openAIWSResponseAccountCacheKey(responseID string) string {
|
||
sum := sha256.Sum256([]byte(responseID))
|
||
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
|
||
}
|
||
|
||
// normalizeOpenAIWSTTL returns ttl unchanged when it is positive, and one hour
// for zero or negative values.
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
	const defaultTTL = time.Hour
	if ttl > 0 {
		return ttl
	}
	return defaultTTL
}
|
||
|
||
// openAIWSSessionTurnStateKey builds the local map key "<groupID>:<sessionHash>"
// shared by the session turn-state and session-conn maps. sessionHash is
// whitespace-trimmed; a blank hash yields "", which callers treat as
// "no session".
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	if trimmed := strings.TrimSpace(sessionHash); trimmed != "" {
		return fmt.Sprintf("%d:%s", groupID, trimmed)
	}
	return ""
}
|
||
|
||
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
|
||
}
|