Files
sub2api/backend/internal/repository/gateway_cache.go
erio 1cd033e521 style: apply gofmt formatting
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-14 09:08:00 +08:00

297 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
_ "embed"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// Redis key prefixes used by the gateway cache.
const (
	// stickySessionPrefix namespaces sticky-session bindings:
	// sticky_session:{groupID}:{sessionHash} -> accountID.
	stickySessionPrefix = "sticky_session:"

	// clientAffinityPrefix namespaces the forward affinity index
	// (client -> accounts): client_affinity:{groupID}:{clientID}.
	clientAffinityPrefix = "client_affinity:"

	// clientAffinityReversePrefix namespaces the reverse affinity index
	// (account -> clients): client_affinity_rev:{groupID}:{accountID}.
	clientAffinityReversePrefix = "client_affinity_rev:"
)
var (
//go:embed lua/get_affinity.lua
getAffinityLua string
//go:embed lua/update_affinity.lua
updateAffinityLua string
//go:embed lua/get_affinity_count.lua
getAffinityCountLua string
//go:embed lua/get_affinity_clients.lua
getAffinityClientsLua string
//go:embed lua/get_affinity_clients_with_scores.lua
getAffinityClientsWithScoresLua string
//go:embed lua/clear_account_affinity.lua
clearAccountAffinityLua string
getAffinityScript = redis.NewScript(getAffinityLua)
updateAffinityScript = redis.NewScript(updateAffinityLua)
getAffinityCountScript = redis.NewScript(getAffinityCountLua)
getAffinityClientsScript = redis.NewScript(getAffinityClientsLua)
getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua)
clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua)
)
// gatewayCache is the Redis-backed implementation of service.GatewayCache.
type gatewayCache struct {
	rdb *redis.Client // client used for every cache operation
}

// NewGatewayCache returns a service.GatewayCache that stores its data in rdb.
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
	cache := &gatewayCache{rdb: rdb}
	return cache
}
// ensureScriptLoaded makes sure script is present in the Redis server's script
// cache. Script.Run inside a pipeline only sends EVALSHA; if Redis has been
// restarted and its script cache was lost, EVALSHA fails with NOSCRIPT. Loading
// the script up front avoids that.
//
// This is best-effort: failures of both the existence check and the load are
// ignored, in which case a later EVALSHA reports the problem.
func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) {
	present, err := script.Exists(ctx, rdb).Result()
	alreadyCached := err == nil && len(present) > 0 && present[0]
	if alreadyCached {
		return
	}
	// Deliberately ignored: see the best-effort note above.
	_ = script.Load(ctx, rdb).Err()
}
// buildSessionKey returns the Redis key of a sticky session. Embedding groupID
// keeps sessions of different groups isolated from each other.
//
// Format: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
	// fmt.Sprint inserts no spaces here: every adjacent operand pair
	// includes a string.
	return fmt.Sprint(stickySessionPrefix, groupID, ":", sessionHash)
}
// GetSessionAccountID returns the account ID bound to the sticky session
// (groupID, sessionHash). Errors from the GET are returned unchanged; go-redis
// reports a missing key as redis.Nil.
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
	reply := c.rdb.Get(ctx, buildSessionKey(groupID, sessionHash))
	return reply.Int64()
}
// SetSessionAccountID binds the sticky session (groupID, sessionHash) to
// accountID. ttl is passed straight through to SET as the key's expiration.
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
	sessionKey := buildSessionKey(groupID, sessionHash)
	reply := c.rdb.Set(ctx, sessionKey, accountID, ttl)
	return reply.Err()
}
// RefreshSessionTTL resets the expiration of the sticky-session binding
// (groupID, sessionHash) to ttl. A missing key is not reported as an error:
// EXPIRE simply does nothing in that case.
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
	sessionKey := buildSessionKey(groupID, sessionHash)
	reply := c.rdb.Expire(ctx, sessionKey, ttl)
	return reply.Err()
}
// DeleteSessionAccountID removes the binding between the sticky session
// (groupID, sessionHash) and its account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
	reply := c.rdb.Del(ctx, buildSessionKey(groupID, sessionHash))
	return reply.Err()
}
// buildAffinityKey returns the forward affinity key (client -> accounts).
//
// Format: client_affinity:{groupID}:{clientID}
func buildAffinityKey(groupID int64, clientID string) string {
	// fmt.Sprint inserts no spaces here: every adjacent operand pair
	// includes a string.
	return fmt.Sprint(clientAffinityPrefix, groupID, ":", clientID)
}
// buildAffinityReverseKey returns the reverse affinity key (account -> clients).
//
// Format: client_affinity_rev:{groupID}:{accountID}
func buildAffinityReverseKey(groupID int64, accountID int64) string {
	// fmt.Sprint inserts no spaces here: every adjacent operand pair
	// includes a string.
	return fmt.Sprint(clientAffinityReversePrefix, groupID, ":", accountID)
}
// GetClientAffinityAccounts returns the account IDs reported by the
// get_affinity script for clientID's forward affinity key in groupID.
//
// The script receives expireThreshold = now - ttl (Unix seconds), presumably
// so that it can ignore members older than ttl; confirm against the Lua
// source. A redis.Nil reply yields (nil, nil). Reply entries that are not
// base-10 int64 values are skipped.
func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) {
	threshold := time.Now().Unix() - int64(ttl.Seconds())
	keys := []string{buildAffinityKey(groupID, clientID)}

	members, err := getAffinityScript.Run(ctx, c.rdb, keys, threshold).StringSlice()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, err
	}

	ids := make([]int64, 0, len(members))
	for _, member := range members {
		if id, parseErr := strconv.ParseInt(member, 10, 64); parseErr == nil {
			ids = append(ids, id)
		}
	}
	return ids, nil
}
// UpdateClientAffinity runs the update_affinity script over both affinity
// indexes for the (clientID, accountID) pair in groupID.
//
//	KEYS: forward key (client -> accounts), reverse key (account -> clients)
//	ARGV: now (Unix s), ttl (s), accountID, expireThreshold (now - ttl), clientID
func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error {
	keys := []string{
		buildAffinityKey(groupID, clientID),
		buildAffinityReverseKey(groupID, accountID),
	}
	nowUnix := time.Now().Unix()
	ttlSec := int64(ttl.Seconds())

	reply := updateAffinityScript.Run(ctx, c.rdb, keys,
		nowUnix, ttlSec, accountID, nowUnix-ttlSec, clientID,
	)
	return reply.Err()
}
// GetAccountAffinityCountBatch returns, for every account in accountIDs, the
// count reported by the get_affinity_count script for that account's reverse
// affinity key in groupID. The script is described as lazily pruning expired
// members; it receives expireThreshold = now - ttl (Unix seconds).
//
// All lookups are sent in a single pipeline. A redis.Nil or non-integer reply
// for an account counts as 0; any other per-command error is returned.
func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) {
	if len(accountIDs) == 0 {
		return map[int64]int64{}, nil
	}
	now := time.Now().Unix()
	expireThreshold := now - int64(ttl.Seconds())

	// Pipelined Script.Run only sends EVALSHA, so make sure the script is cached.
	ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript)
	pipe := c.rdb.Pipeline()
	cmds := make([]*redis.Cmd, len(accountIDs))
	for i, accID := range accountIDs {
		key := buildAffinityReverseKey(groupID, accID)
		cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold)
	}

	// Exec reports only the FIRST failing command's error. If that happens to
	// be redis.Nil, a later real error (e.g. NOSCRIPT) would be hidden, so each
	// command's error is also checked individually below.
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return nil, err
	}

	result := make(map[int64]int64, len(accountIDs))
	for i, accID := range accountIDs {
		cmd := cmds[i]
		if err := cmd.Err(); err != nil && err != redis.Nil {
			return nil, err
		}
		// redis.Nil or a non-integer reply counts as 0 (best-effort).
		count, _ := cmd.Int64()
		result[accID] = count
	}
	return result, nil
}
// GetAccountAffinityClientsBatch returns, for each account, the de-duplicated
// affinity clients found across all of that account's groups.
//
// accountGroups maps accountID -> groupIDs. The reverse affinity key of every
// (groupID, accountID) pair is queried with the get_affinity_clients script
// in a single pipeline. The script receives expireThreshold = now - ttl
// (Unix seconds).
//
// Accounts with no clients are absent from the result. Within one account,
// clients keep the order in which they were first seen, with groups visited
// in the order they are listed. A redis.Nil reply contributes no clients. Any
// other per-command error is returned.
func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) {
	if len(accountGroups) == 0 {
		return map[int64][]string{}, nil
	}
	now := time.Now().Unix()
	expireThreshold := now - int64(ttl.Seconds())

	// One query per (accountID, groupID) pair.
	type queryItem struct {
		accountID int64
		groupID   int64
	}
	var queries []queryItem
	for accID, groupIDs := range accountGroups {
		for _, gID := range groupIDs {
			queries = append(queries, queryItem{accountID: accID, groupID: gID})
		}
	}
	if len(queries) == 0 {
		// Every account has an empty group list: nothing to look up, so skip
		// the SCRIPT EXISTS round-trip and the empty pipeline.
		return map[int64][]string{}, nil
	}

	// Pipelined Script.Run only sends EVALSHA, so make sure the script is cached.
	ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript)
	pipe := c.rdb.Pipeline()
	cmds := make([]*redis.Cmd, len(queries))
	for i, q := range queries {
		key := buildAffinityReverseKey(q.groupID, q.accountID)
		cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold)
	}

	// Exec reports only the FIRST failing command's error. If that happens to
	// be redis.Nil, a later real error would be hidden, so each command's error
	// is also checked individually below.
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return nil, err
	}

	// Merge results: de-duplicate clientIDs of the same account across groups.
	result := make(map[int64][]string, len(accountGroups))
	seen := make(map[int64]map[string]struct{}, len(accountGroups))
	for i, q := range queries {
		cmd := cmds[i]
		if err := cmd.Err(); err != nil && err != redis.Nil {
			return nil, err
		}
		// redis.Nil or an unexpected reply shape contributes no clients.
		clients, _ := cmd.StringSlice()
		if len(clients) == 0 {
			continue
		}
		if seen[q.accountID] == nil {
			seen[q.accountID] = make(map[string]struct{})
		}
		for _, clientID := range clients {
			if _, exists := seen[q.accountID][clientID]; !exists {
				seen[q.accountID][clientID] = struct{}{}
				result[q.accountID] = append(result[q.accountID], clientID)
			}
		}
	}
	return result, nil
}
// GetAccountAffinityClientsWithScores returns the affinity clients of a single
// account across all of groupIDs, each with its last-active time. A client
// present in several groups is reported once, with its most recent timestamp.
// The result is sorted by last-active time, descending, via
// service.SortAffinityClients.
//
// The get_affinity_clients_with_scores script receives expireThreshold =
// now - ttl (Unix seconds). Its reply is expected to be a flat
// [clientID, score, clientID, score, ...] list with scores in integer Unix
// seconds. Entries whose score does not parse as an int64 are skipped.
// A redis.Nil reply contributes nothing. Any other per-command error is
// returned.
func (c *gatewayCache) GetAccountAffinityClientsWithScores(
	ctx context.Context,
	accountID int64,
	groupIDs []int64,
	ttl time.Duration,
) ([]service.AffinityClient, error) {
	if len(groupIDs) == 0 {
		return nil, nil
	}
	now := time.Now().Unix()
	expireThreshold := now - int64(ttl.Seconds())

	// Pipelined Script.Run only sends EVALSHA, so make sure the script is cached.
	ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript)
	pipe := c.rdb.Pipeline()
	cmds := make([]*redis.Cmd, len(groupIDs))
	for i, gID := range groupIDs {
		key := buildAffinityReverseKey(gID, accountID)
		cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold)
	}

	// Exec reports only the FIRST failing command's error. If that happens to
	// be redis.Nil, a later real error would be hidden, so each command's error
	// is also checked individually below.
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return nil, err
	}

	// Merge across groups, keeping the most recent lastActive per clientID.
	latest := make(map[string]int64) // clientID -> max timestamp (Unix seconds)
	for _, cmd := range cmds {
		if err := cmd.Err(); err != nil && err != redis.Nil {
			return nil, err
		}
		// redis.Nil or an unexpected reply shape contributes no entries.
		vals, _ := cmd.StringSlice()
		for j := 0; j+1 < len(vals); j += 2 {
			clientID := vals[j]
			ts, err := strconv.ParseInt(vals[j+1], 10, 64)
			if err != nil {
				// Skip rather than report the client as last active at the
				// Unix epoch; this matches how GetClientAffinityAccounts
				// skips unparsable members.
				continue
			}
			if existing, ok := latest[clientID]; !ok || ts > existing {
				latest[clientID] = ts
			}
		}
	}

	result := make([]service.AffinityClient, 0, len(latest))
	for clientID, ts := range latest {
		result = append(result, service.AffinityClient{
			ClientID:   clientID,
			LastActive: time.Unix(ts, 0),
		})
	}
	// Map iteration order is random; sorting makes the output deterministic.
	service.SortAffinityClients(result)
	return result, nil
}
// ClearAccountAffinity removes the affinity records (forward and reverse
// indexes) of accountID in every group listed in groupIDs.
//
// For each group the clear_account_affinity script is run with the reverse
// key as KEYS[1] and (groupID, accountID) as ARGV. It is intended to read the
// reverse index to find the clients, remove the account from each client's
// forward index, and then delete the reverse index. All groups are processed
// in a single pipeline. A redis.Nil reply is not an error; any other
// per-command error is returned.
//
// NOTE(review): the forward keys modified by the script are not declared in
// KEYS, which Redis Cluster requires. Confirm that the deployment is a
// standalone or single-slot setup.
func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}

	// Pipelined Script.Run only sends EVALSHA, so make sure the script is cached.
	ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript)
	pipe := c.rdb.Pipeline()
	cmds := make([]*redis.Cmd, 0, len(groupIDs))
	for _, gID := range groupIDs {
		revKey := buildAffinityReverseKey(gID, accountID)
		cmds = append(cmds, clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID))
	}

	// Exec reports only the FIRST failing command's error. If that happens to
	// be redis.Nil, a later real failure would be hidden, so each command's
	// error is also checked individually below.
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return err
	}
	for _, cmd := range cmds {
		if err := cmd.Err(); err != nil && err != redis.Nil {
			return err
		}
	}
	return nil
}