// File: sub2api/backend/internal/service/openai_ws_state_store.go
// Snapshot: 2026-02-28 15:01:20 +08:00 (441 lines, 12 KiB, Go).
//
// Note: this file contains non-ASCII (Chinese) comment text, which some source
// viewers flag as "ambiguous Unicode characters". The characters are intentional.

package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
)
// Tunables for the OpenAI WS state store.
const (
	// openAIWSResponseAccountCachePrefix prefixes the GatewayCache keys that map
	// a (hashed) response ID to an account ID.
	openAIWSResponseAccountCachePrefix = "openai:response:"

	// openAIWSStateStoreCleanupInterval is the minimum gap between two
	// incremental cleanup passes.
	openAIWSStateStoreCleanupInterval = 1 * time.Minute
	// openAIWSStateStoreCleanupMaxPerMap bounds how many entries a single
	// cleanup pass examines in each map (512).
	openAIWSStateStoreCleanupMaxPerMap = 1 << 9
	// openAIWSStateStoreMaxEntriesPerMap is the hard size cap for each
	// in-memory map (65536).
	openAIWSStateStoreMaxEntriesPerMap = 1 << 16
	// openAIWSStateStoreRedisTimeout is the timeout applied to each
	// GatewayCache call made by the store.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)
// Binding records stored in the state store's maps. Each record pairs the
// bound value with the instant after which lookups treat it as stale.
type (
	// openAIWSAccountBinding maps a response ID to the account that served it.
	openAIWSAccountBinding struct {
		accountID int64
		expiresAt time.Time
	}

	// openAIWSConnBinding maps a response ID to a WS connection ID.
	openAIWSConnBinding struct {
		connID    string
		expiresAt time.Time
	}

	// openAIWSTurnStateBinding maps a session key to its latest turn state.
	openAIWSTurnStateBinding struct {
		turnState string
		expiresAt time.Time
	}

	// openAIWSSessionConnBinding maps a session key to a WS connection ID.
	openAIWSSessionConnBinding struct {
		connID    string
		expiresAt time.Time
	}
)
// OpenAIWSStateStore manages the sticky (affinity) state for OpenAI WSv2:
//   - response_id -> account_id, used to route follow-up requests in a chain;
//   - response_id -> conn_id, used to reuse context within a connection;
//   - (groupID, sessionHash) -> turn state and -> conn_id, kept per session.
//
// The response_id -> account_id mapping prefers GatewayCache (Redis) while also
// maintaining a local hot cache. The response_id -> conn_id mapping is only
// meaningful within this process.
type OpenAIWSStateStore interface {
	// Response -> account bindings (GatewayCache-backed where available).
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	// Response -> connection bindings (process-local).
	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)

	// Session -> turn state bindings.
	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)

	// Session -> connection bindings.
	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}
// defaultOpenAIWSStateStore is the default OpenAIWSStateStore implementation.
// Each map is guarded by its own RWMutex, so the four binding kinds never
// contend with each other. Response -> account bindings are additionally
// written to cache when one is configured.
type defaultOpenAIWSStateStore struct {
	// cache is the optional shared backend for response -> account bindings.
	// It may be nil, in which case only the local map is used.
	cache GatewayCache

	// response ID -> account ID (local hot copy).
	responseToAccountMu sync.RWMutex
	responseToAccount   map[string]openAIWSAccountBinding

	// response ID -> WS connection ID.
	responseToConnMu sync.RWMutex
	responseToConn   map[string]openAIWSConnBinding

	// "<groupID>:<sessionHash>" -> turn state.
	sessionToTurnStateMu sync.RWMutex
	sessionToTurnState   map[string]openAIWSTurnStateBinding

	// "<groupID>:<sessionHash>" -> WS connection ID.
	sessionToConnMu sync.RWMutex
	sessionToConn   map[string]openAIWSSessionConnBinding

	// lastCleanupUnixNano is the wall-clock time (Unix ns) of the last cleanup pass.
	lastCleanupUnixNano atomic.Int64
}
// NewOpenAIWSStateStore creates the default WS state store. cache may be nil,
// in which case response -> account bindings are kept in memory only.
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
	const initialMapCapacity = 256

	s := new(defaultOpenAIWSStateStore)
	s.cache = cache
	s.responseToAccount = make(map[string]openAIWSAccountBinding, initialMapCapacity)
	s.responseToConn = make(map[string]openAIWSConnBinding, initialMapCapacity)
	s.sessionToTurnState = make(map[string]openAIWSTurnStateBinding, initialMapCapacity)
	s.sessionToConn = make(map[string]openAIWSSessionConnBinding, initialMapCapacity)

	// Start the cleanup clock now so the first sweep happens one interval later.
	s.lastCleanupUnixNano.Store(time.Now().UnixNano())
	return s
}
// BindResponseAccount records that responseID was served by accountID.
// The binding is stored in the local hot cache first. When a GatewayCache is
// configured, it is then written to the cache under a hashed key with the same TTL.
// Blank response IDs and non-positive account IDs are ignored. A non-positive
// ttl is replaced by normalizeOpenAIWSTTL's default. Only the cache write can
// return an error; the local binding is kept either way.
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if accountID <= 0 || id == "" {
		return nil
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	binding := openAIWSAccountBinding{accountID: accountID, expiresAt: time.Now().Add(ttl)}
	s.responseToAccountMu.Lock()
	ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToAccount[id] = binding
	s.responseToAccountMu.Unlock()

	if s.cache == nil {
		return nil
	}
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.SetSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id), accountID, ttl)
}
// GetResponseAccount returns the account bound to responseID, or 0 when unknown.
// It checks the local hot cache first. If the local entry is missing or no
// longer strictly before its expiry, it queries the GatewayCache (when configured).
// Cache errors and non-positive cached values are reported as a miss (0, nil)
// on purpose, so routing degrades instead of failing. The returned error is
// always nil.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return 0, nil
	}
	s.maybeCleanup()

	now := time.Now()
	s.responseToAccountMu.RLock()
	binding, found := s.responseToAccount[id]
	s.responseToAccountMu.RUnlock()
	if found && now.Before(binding.expiresAt) {
		return binding.accountID, nil
	}

	if s.cache == nil {
		return 0, nil
	}
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	cached, err := s.cache.GetSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
	if err == nil && cached > 0 {
		return cached, nil
	}
	// A cache read failure must not block the main flow; degrade to a miss.
	return 0, nil
}
// DeleteResponseAccount removes the response -> account binding from the local
// hot cache. When a GatewayCache is configured, it removes it from the cache too.
// A blank response ID is a no-op. Only the cache deletion can return an error.
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return nil
	}

	func() {
		s.responseToAccountMu.Lock()
		defer s.responseToAccountMu.Unlock()
		delete(s.responseToAccount, id)
	}()

	if s.cache != nil {
		cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
		defer cancel()
		return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
	}
	return nil
}
// BindResponseConn records, in this process only, which WS connection produced
// responseID. Blank IDs (after trimming) are ignored. A non-positive ttl is
// replaced by the default.
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
	id, conn := normalizeOpenAIWSResponseID(responseID), strings.TrimSpace(connID)
	if id == "" || conn == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.responseToConnMu.Lock()
	defer s.responseToConnMu.Unlock()
	ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToConn[id] = openAIWSConnBinding{connID: conn, expiresAt: time.Now().Add(ttl)}
}
// GetResponseConn returns the connection ID bound to responseID. It reports
// false when there is no binding, the binding has expired, or the stored ID is
// blank.
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()

	s.responseToConnMu.RLock()
	binding, found := s.responseToConn[id]
	s.responseToConnMu.RUnlock()

	usable := found && !now.After(binding.expiresAt) && strings.TrimSpace(binding.connID) != ""
	if !usable {
		return "", false
	}
	return binding.connID, true
}
// DeleteResponseConn drops the response -> connection binding. A blank
// response ID is a no-op.
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
	if id := normalizeOpenAIWSResponseID(responseID); id != "" {
		s.responseToConnMu.Lock()
		delete(s.responseToConn, id)
		s.responseToConnMu.Unlock()
	}
}
// BindSessionTurnState stores the turn state for the (groupID, sessionHash)
// session and overwrites any previous value. A blank session hash or turn state
// is ignored. A non-positive ttl is replaced by the default.
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return
	}
	state := strings.TrimSpace(turnState)
	if state == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToTurnStateMu.Lock()
	defer s.sessionToTurnStateMu.Unlock()
	ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToTurnState[key] = openAIWSTurnStateBinding{turnState: state, expiresAt: time.Now().Add(ttl)}
}
// GetSessionTurnState returns the stored turn state for the (groupID,
// sessionHash) session. It reports false when there is none, it has expired,
// or it is blank.
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()

	s.sessionToTurnStateMu.RLock()
	binding, found := s.sessionToTurnState[key]
	s.sessionToTurnStateMu.RUnlock()

	if found && !now.After(binding.expiresAt) && strings.TrimSpace(binding.turnState) != "" {
		return binding.turnState, true
	}
	return "", false
}
// DeleteSessionTurnState forgets the turn state for the (groupID, sessionHash)
// session. A blank session hash is a no-op.
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
	if key := openAIWSSessionTurnStateKey(groupID, sessionHash); key != "" {
		s.sessionToTurnStateMu.Lock()
		delete(s.sessionToTurnState, key)
		s.sessionToTurnStateMu.Unlock()
	}
}
// BindSessionConn records which WS connection serves the (groupID, sessionHash)
// session. A blank session hash or connection ID is ignored. A non-positive ttl
// is replaced by the default.
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
	// Session-conn keys share the turn-state key format. The two maps are
	// separate, so identical keys do not collide.
	key, conn := openAIWSSessionTurnStateKey(groupID, sessionHash), strings.TrimSpace(connID)
	if key == "" || conn == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToConnMu.Lock()
	defer s.sessionToConnMu.Unlock()
	ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToConn[key] = openAIWSSessionConnBinding{connID: conn, expiresAt: time.Now().Add(ttl)}
}
// GetSessionConn returns the connection ID bound to the (groupID, sessionHash)
// session. It reports false when there is none, it has expired, or it is blank.
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()

	s.sessionToConnMu.RLock()
	binding, found := s.sessionToConn[key]
	s.sessionToConnMu.RUnlock()

	switch {
	case !found, now.After(binding.expiresAt), strings.TrimSpace(binding.connID) == "":
		return "", false
	default:
		return binding.connID, true
	}
}
// DeleteSessionConn drops the session -> connection binding for (groupID,
// sessionHash). A blank session hash is a no-op.
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
	if key := openAIWSSessionTurnStateKey(groupID, sessionHash); key != "" {
		s.sessionToConnMu.Lock()
		delete(s.sessionToConn, key)
		s.sessionToConnMu.Unlock()
	}
}
// maybeCleanup runs an incremental expiry sweep over all four maps, at most once
// per openAIWSStateStoreCleanupInterval. The Bind*/Get* methods call it
// opportunistically.
//
// The last-run timestamp is stored as wall-clock Unix nanoseconds. A
// compare-and-swap on that value elects a single caller to perform each sweep;
// callers that lose the race return immediately. Because the timestamp carries
// no monotonic reading, a backwards wall-clock step makes the elapsed time
// negative. That case is treated as "due", so cleanup does not stall until the
// clock catches up.
//
// Each map is swept under its own write lock, and at most
// openAIWSStateStoreCleanupMaxPerMap entries are examined per map. This bounds
// how long any single lock is held and avoids a full scan at large sizes.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	nowNano := now.UnixNano()
	lastNano := s.lastCleanupUnixNano.Load()
	elapsed := time.Duration(nowNano - lastNano)
	if elapsed >= 0 && elapsed < openAIWSStateStoreCleanupInterval {
		return
	}
	if !s.lastCleanupUnixNano.CompareAndSwap(lastNano, nowNano) {
		// Another goroutine claimed this sweep.
		return
	}

	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToAccountMu.Unlock()

	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToConnMu.Unlock()

	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToTurnStateMu.Unlock()

	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToConnMu.Unlock()
}
// cleanupExpiredOpenAIWSBindings deletes entries whose expiry (as returned by
// expiresAt) is strictly before now. It examines at most maxScan entries per
// call, so a single pass stays cheap on large maps. Go leaves map iteration order
// unspecified, so successive calls generally visit different entries. An empty
// map or a non-positive maxScan is a no-op. Callers must hold the write lock
// that guards bindings.
func cleanupExpiredOpenAIWSBindings[T any](bindings map[string]T, now time.Time, maxScan int, expiresAt func(T) time.Time) {
	if len(bindings) == 0 || maxScan <= 0 {
		return
	}
	scanned := 0
	for key, binding := range bindings {
		if now.After(expiresAt(binding)) {
			// Deleting the current key during range is permitted by the Go spec.
			delete(bindings, key)
		}
		scanned++
		if scanned >= maxScan {
			return
		}
	}
}

// cleanupExpiredAccountBindings prunes expired response -> account bindings,
// scanning at most maxScan entries.
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
	cleanupExpiredOpenAIWSBindings(bindings, now, maxScan, func(b openAIWSAccountBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredConnBindings prunes expired response -> connection bindings,
// scanning at most maxScan entries.
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
	cleanupExpiredOpenAIWSBindings(bindings, now, maxScan, func(b openAIWSConnBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredTurnStateBindings prunes expired session -> turn state
// bindings, scanning at most maxScan entries.
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
	cleanupExpiredOpenAIWSBindings(bindings, now, maxScan, func(b openAIWSTurnStateBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredSessionConnBindings prunes expired session -> connection
// bindings, scanning at most maxScan entries.
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
	cleanupExpiredOpenAIWSBindings(bindings, now, maxScan, func(b openAIWSSessionConnBinding) time.Time { return b.expiresAt })
}
// ensureBindingCapacity makes room for incomingKey in a map capped at maxEntries.
// If the map is already at the cap and incomingKey is absent (so writing it
// would grow the map), one arbitrary entry is evicted. Expiry is not considered
// when picking the victim; keeping memory bounded takes priority. A non-positive
// maxEntries disables the cap. Callers must hold the map's write lock.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, present := bindings[incomingKey]; present {
		// Overwriting an existing key does not grow the map.
		return
	}
	// Reaching this point means len(bindings) >= maxEntries >= 1, so the loop
	// below always picks a victim.
	var victim string
	for k := range bindings {
		victim = k
		break
	}
	delete(bindings, victim)
}
// normalizeOpenAIWSResponseID canonicalizes a response ID by stripping leading
// and trailing whitespace. An all-blank ID normalizes to "".
func normalizeOpenAIWSResponseID(responseID string) string {
	normalized := strings.TrimSpace(responseID)
	return normalized
}
// openAIWSResponseAccountCacheKey derives the GatewayCache key for a response ID.
// The key is the fixed prefix followed by the lowercase hex SHA-256 of the ID,
// so every key has the same length regardless of the ID's size.
func openAIWSResponseAccountCacheKey(responseID string) string {
	digest := sha256.Sum256([]byte(responseID))
	encoded := hex.EncodeToString(digest[:])
	return openAIWSResponseAccountCachePrefix + encoded
}
// normalizeOpenAIWSTTL returns ttl unchanged when it is positive. Otherwise it
// falls back to a one-hour default.
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
	const defaultTTL = time.Hour
	if ttl > 0 {
		return ttl
	}
	return defaultTTL
}
// openAIWSSessionTurnStateKey builds the in-memory key "<groupID>:<sessionHash>"
// used by the session-scoped maps. Surrounding whitespace is trimmed from
// sessionHash first. It returns "" when the trimmed hash is empty, which callers
// use to skip the operation.
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	if hash := strings.TrimSpace(sessionHash); hash != "" {
		return fmt.Sprintf("%d:%s", groupID, hash)
	}
	return ""
}
// withOpenAIWSStateStoreRedisTimeout derives a context bounded by
// openAIWSStateStoreRedisTimeout for a single cache call. A nil ctx is tolerated
// and replaced with context.Background(). Callers must invoke the returned
// cancel function.
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, openAIWSStateStoreRedisTimeout)
}