feat(sync): full code sync from release
This commit is contained in:
440
backend/internal/service/openai_ws_state_store.go
Normal file
440
backend/internal/service/openai_ws_state_store.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIWSResponseAccountCachePrefix = "openai:response:"
|
||||
openAIWSStateStoreCleanupInterval = time.Minute
|
||||
openAIWSStateStoreCleanupMaxPerMap = 512
|
||||
openAIWSStateStoreMaxEntriesPerMap = 65536
|
||||
openAIWSStateStoreRedisTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
type openAIWSAccountBinding struct {
|
||||
accountID int64
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type openAIWSConnBinding struct {
|
||||
connID string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type openAIWSTurnStateBinding struct {
|
||||
turnState string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type openAIWSSessionConnBinding struct {
|
||||
connID string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// OpenAIWSStateStore 管理 WSv2 的粘连状态。
|
||||
// - response_id -> account_id 用于续链路由
|
||||
// - response_id -> conn_id 用于连接内上下文复用
|
||||
//
|
||||
// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。
|
||||
// response_id -> conn_id 仅在本进程内有效。
|
||||
type OpenAIWSStateStore interface {
|
||||
BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
|
||||
GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
|
||||
DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error
|
||||
|
||||
BindResponseConn(responseID, connID string, ttl time.Duration)
|
||||
GetResponseConn(responseID string) (string, bool)
|
||||
DeleteResponseConn(responseID string)
|
||||
|
||||
BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
|
||||
GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
|
||||
DeleteSessionTurnState(groupID int64, sessionHash string)
|
||||
|
||||
BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
|
||||
GetSessionConn(groupID int64, sessionHash string) (string, bool)
|
||||
DeleteSessionConn(groupID int64, sessionHash string)
|
||||
}
|
||||
|
||||
type defaultOpenAIWSStateStore struct {
|
||||
cache GatewayCache
|
||||
|
||||
responseToAccountMu sync.RWMutex
|
||||
responseToAccount map[string]openAIWSAccountBinding
|
||||
responseToConnMu sync.RWMutex
|
||||
responseToConn map[string]openAIWSConnBinding
|
||||
sessionToTurnStateMu sync.RWMutex
|
||||
sessionToTurnState map[string]openAIWSTurnStateBinding
|
||||
sessionToConnMu sync.RWMutex
|
||||
sessionToConn map[string]openAIWSSessionConnBinding
|
||||
|
||||
lastCleanupUnixNano atomic.Int64
|
||||
}
|
||||
|
||||
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
|
||||
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
|
||||
store := &defaultOpenAIWSStateStore{
|
||||
cache: cache,
|
||||
responseToAccount: make(map[string]openAIWSAccountBinding, 256),
|
||||
responseToConn: make(map[string]openAIWSConnBinding, 256),
|
||||
sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
|
||||
sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
|
||||
}
|
||||
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
s.responseToAccountMu.Lock()
|
||||
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
cacheKey := openAIWSResponseAccountCacheKey(id)
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return 0, nil
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.responseToAccountMu.RLock()
|
||||
if binding, ok := s.responseToAccount[id]; ok {
|
||||
if now.Before(binding.expiresAt) {
|
||||
accountID := binding.accountID
|
||||
s.responseToAccountMu.RUnlock()
|
||||
return accountID, nil
|
||||
}
|
||||
}
|
||||
s.responseToAccountMu.RUnlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
cacheKey := openAIWSResponseAccountCacheKey(id)
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
// 缓存读取失败不阻断主流程,按未命中降级。
|
||||
return 0, nil
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
s.responseToAccountMu.Lock()
|
||||
delete(s.responseToAccount, id)
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
conn := strings.TrimSpace(connID)
|
||||
if id == "" || conn == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.responseToConnMu.Lock()
|
||||
ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.responseToConn[id] = openAIWSConnBinding{
|
||||
connID: conn,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.responseToConnMu.RLock()
|
||||
binding, ok := s.responseToConn[id]
|
||||
s.responseToConnMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.connID, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
s.responseToConnMu.Lock()
|
||||
delete(s.responseToConn, id)
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
state := strings.TrimSpace(turnState)
|
||||
if key == "" || state == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.sessionToTurnState[key] = openAIWSTurnStateBinding{
|
||||
turnState: state,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.sessionToTurnStateMu.RLock()
|
||||
binding, ok := s.sessionToTurnState[key]
|
||||
s.sessionToTurnStateMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.turnState, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
delete(s.sessionToTurnState, key)
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
conn := strings.TrimSpace(connID)
|
||||
if key == "" || conn == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.sessionToConnMu.Lock()
|
||||
ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.sessionToConn[key] = openAIWSSessionConnBinding{
|
||||
connID: conn,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.sessionToConnMu.RLock()
|
||||
binding, ok := s.sessionToConn[key]
|
||||
s.sessionToConnMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.connID, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToConnMu.Lock()
|
||||
delete(s.sessionToConn, key)
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
last := time.Unix(0, s.lastCleanupUnixNano.Load())
|
||||
if now.Sub(last) < openAIWSStateStoreCleanupInterval {
|
||||
return
|
||||
}
|
||||
if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
|
||||
return
|
||||
}
|
||||
|
||||
// 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。
|
||||
s.responseToAccountMu.Lock()
|
||||
cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
s.responseToConnMu.Lock()
|
||||
cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
||||
s.responseToConnMu.Unlock()
|
||||
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
|
||||
s.sessionToConnMu.Lock()
|
||||
cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
|
||||
if len(bindings) < maxEntries || maxEntries <= 0 {
|
||||
return
|
||||
}
|
||||
if _, exists := bindings[incomingKey]; exists {
|
||||
return
|
||||
}
|
||||
// 固定上限保护:淘汰任意一项,优先保证内存有界。
|
||||
for key := range bindings {
|
||||
delete(bindings, key)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSResponseID(responseID string) string {
|
||||
return strings.TrimSpace(responseID)
|
||||
}
|
||||
|
||||
func openAIWSResponseAccountCacheKey(responseID string) string {
|
||||
sum := sha256.Sum256([]byte(responseID))
|
||||
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return time.Hour
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
|
||||
hash := strings.TrimSpace(sessionHash)
|
||||
if hash == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d:%s", groupID, hash)
|
||||
}
|
||||
|
||||
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
|
||||
}
|
||||
Reference in New Issue
Block a user