diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 58291b66..ec4bf40e 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,14 +2,42 @@ package repository import ( "context" + _ "embed" "fmt" + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) -const stickySessionPrefix = "sticky_session:" +const ( + stickySessionPrefix = "sticky_session:" + clientAffinityPrefix = "client_affinity:" + clientAffinityReversePrefix = "client_affinity_rev:" +) + +var ( + //go:embed lua/get_affinity.lua + getAffinityLua string + //go:embed lua/update_affinity.lua + updateAffinityLua string + //go:embed lua/get_affinity_count.lua + getAffinityCountLua string + //go:embed lua/get_affinity_clients.lua + getAffinityClientsLua string + //go:embed lua/get_affinity_clients_with_scores.lua + getAffinityClientsWithScoresLua string + //go:embed lua/clear_account_affinity.lua + clearAccountAffinityLua string + + getAffinityScript = redis.NewScript(getAffinityLua) + updateAffinityScript = redis.NewScript(updateAffinityLua) + getAffinityCountScript = redis.NewScript(getAffinityCountLua) + getAffinityClientsScript = redis.NewScript(getAffinityClientsLua) + getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua) + clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua) +) type gatewayCache struct { rdb *redis.Client @@ -19,6 +47,16 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } +// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。 +// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失, +// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。 +func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) { + exists, err := script.Exists(ctx, rdb).Result() + if err != nil || len(exists) == 0 || !exists[0] { + _ = script.Load(ctx, rdb).Err() + } +} + // buildSessionKey 构建 session key,包含 groupID 实现分组隔离 // 格式: sticky_session:{groupID}:{sessionHash} func buildSessionKey(groupID int64, sessionHash string) string { @@ -41,13 +79,218 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 -// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, -// 以便下次请求能够重新选择可用账号。 -// -// DeleteSessionAccountID removes the sticky session binding for the given session. -// Called when the bound account becomes unavailable (e.g., error status, disabled, -// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } + +// buildAffinityKey 构建正向亲和 key(client → accounts) +// 格式: client_affinity:{groupID}:{clientID} +func buildAffinityKey(groupID int64, clientID string) string { + return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID) +} + +// buildAffinityReverseKey 构建反向亲和 key(account → clients) +// 格式: client_affinity_rev:{groupID}:{accountID} +func buildAffinityReverseKey(groupID int64, accountID int64) string { + return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID) +} + +func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) { + key := buildAffinityKey(groupID, clientID) + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + accountIDs := make([]int64, 0, len(result)) + for _, s := range result { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + } + return accountIDs, nil +} + +func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error { + fwdKey := buildAffinityKey(groupID, clientID) + revKey := buildAffinityReverseKey(groupID, accountID) + now := time.Now().Unix() + ttlSeconds := int64(ttl.Seconds()) + expireThreshold := now - ttlSeconds + + return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey}, + now, ttlSeconds, accountID, expireThreshold, clientID, + ).Err() +} + +// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员) +func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) { + if len(accountIDs) == 0 { + return map[int64]int64{}, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(accountIDs)) + for i, accID := range accountIDs { + key := buildAffinityReverseKey(groupID, accID) + cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + result := make(map[int64]int64, len(accountIDs)) + for i, accID := range accountIDs { + count, _ := cmds[i].Int64() + result[accID] = count + } + return result, nil +} + +// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。 +// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。 +func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) { + if len(accountGroups) == 0 { + return map[int64][]string{}, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + // 构建所有 (accountID, groupID) 组合的查询 + type queryItem struct { + accountID int64 + groupID int64 + } + var queries []queryItem + for accID, groupIDs := range accountGroups { + for _, gID := range groupIDs { + queries = append(queries, queryItem{accountID: accID, groupID: gID}) + } + } + + ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(queries)) + for i, q := range queries { + key := buildAffinityReverseKey(q.groupID, q.accountID) + cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重 + result := make(map[int64][]string, len(accountGroups)) + seen := make(map[int64]map[string]struct{}, len(accountGroups)) + for i, q := range queries { + clients, _ := cmds[i].StringSlice() + if len(clients) == 0 { + continue + } + if seen[q.accountID] == nil { + seen[q.accountID] = make(map[string]struct{}) + } + for _, clientID := range clients { + if _, exists := seen[q.accountID][clientID]; !exists { + seen[q.accountID][clientID] = struct{}{} + result[q.accountID] = append(result[q.accountID], clientID) + } + } + } + return result, nil +} + +// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。 +func (c *gatewayCache) GetAccountAffinityClientsWithScores( + ctx context.Context, + accountID int64, + groupIDs []int64, + ttl time.Duration, +) ([]service.AffinityClient, error) { + if len(groupIDs) == 0 { + return nil, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(groupIDs)) + for i, gID := range groupIDs { + key := buildAffinityReverseKey(gID, accountID) + cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + // 合并跨组结果,同一 clientID 取最近的 lastActive + seen := make(map[string]int64) // clientID → max timestamp + for _, cmd := range cmds { + vals, _ := cmd.StringSlice() + // vals 格式: [clientID1, score1, clientID2, score2, ...] + for j := 0; j+1 < len(vals); j += 2 { + clientID := vals[j] + ts, _ := strconv.ParseInt(vals[j+1], 10, 64) + if existing, ok := seen[clientID]; !ok || ts > existing { + seen[clientID] = ts + } + } + } + + result := make([]service.AffinityClient, 0, len(seen)) + for clientID, ts := range seen { + result = append(result, service.AffinityClient{ + ClientID: clientID, + LastActive: time.Unix(ts, 0), + }) + } + + // 按最后活跃时间降序排序 + service.SortAffinityClients(result) + + return result, nil +} + +// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。 +// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端, +// 从每个客户端的正向索引中移除该账号,然后删除反向索引。 +func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + + ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript) + + pipe := c.rdb.Pipeline() + for _, gID := range groupIDs { + revKey := buildAffinityReverseKey(gID, accountID) + clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return err + } + return nil +} diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index f9fd6742..5c18a438 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -65,9 +65,6 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - panic("unexpected") -} func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } @@ -131,9 +128,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { - panic("unexpected") -} func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } @@ -200,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { @@ -216,29 +210,6 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS panic("unexpected") } -type userSubRepoStubForGroupUpdate struct { - userSubRepoNoop - getActiveSub *UserSubscription - getActiveErr error - called bool - calledUserID int64 - calledGroupID int64 -} - -func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { - s.called = true - s.calledUserID = userID - s.calledGroupID = groupID - if s.getActiveErr != nil { - return nil, s.getActiveErr - } - if s.getActiveSub == nil { - return nil, ErrSubscriptionNotFound - } - clone := *s.getActiveSub - return &clone, nil -} - // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -431,49 +402,14 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} - userRepo := &userRepoStubForGroupUpdate{} - userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound} - svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} - - // 无有效订阅时应拒绝绑定 - _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) - require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err)) - require.True(t, userSubRepo.called) - require.Equal(t, int64(42), userSubRepo.calledUserID) - require.Equal(t, int64(10), userSubRepo.calledGroupID) - require.False(t, userRepo.addGroupCalled) -} - -func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) { - existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} - apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + // 订阅类型分组应被阻止绑定 _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) - require.False(t, userRepo.addGroupCalled) -} - -func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) { - existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} - apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} - userRepo := &userRepoStubForGroupUpdate{} - userSubRepo := &userSubRepoStubForGroupUpdate{ - getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10}, - } - svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} - - got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) - require.NoError(t, err) - require.True(t, userSubRepo.called) - require.NotNil(t, got.APIKey.GroupID) - require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index e88694f5..7f6c748f 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -46,12 +46,9 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - return nil -} -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } // --- mock: APIKeyAuthCacheInvalidator ---