fix: 优化调度快照缓存以避免 Redis 大 MGET

This commit is contained in:
ius
2026-04-08 10:10:15 -07:00
parent 0d69c0cd64
commit 265687b56d
12 changed files with 631 additions and 163 deletions

View File

@@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1)
case "mget":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)

View File

@@ -15,19 +15,39 @@ const (
schedulerBucketSetKey = "sched:buckets"
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
schedulerAccountPrefix = "sched:acc:"
schedulerAccountMetaPrefix = "sched:meta:"
schedulerActivePrefix = "sched:active:"
schedulerReadyPrefix = "sched:ready:"
schedulerVersionPrefix = "sched:ver:"
schedulerSnapshotPrefix = "sched:"
schedulerLockPrefix = "sched:lock:"
defaultSchedulerSnapshotMGetChunkSize = 128
defaultSchedulerSnapshotWriteChunkSize = 256
)
type schedulerCache struct {
rdb *redis.Client
rdb *redis.Client
mgetChunkSize int
writeChunkSize int
}
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
return &schedulerCache{rdb: rdb}
return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize)
}
func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache {
if mgetChunkSize <= 0 {
mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize
}
if writeChunkSize <= 0 {
writeChunkSize = defaultSchedulerSnapshotWriteChunkSize
}
return &schedulerCache{
rdb: rdb,
mgetChunkSize: mgetChunkSize,
writeChunkSize: writeChunkSize,
}
}
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
@@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
keys := make([]string, 0, len(ids))
for _, id := range ids {
keys = append(keys, schedulerAccountKey(id))
keys = append(keys, schedulerAccountMetaKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
values, err := c.mgetChunked(ctx, keys)
if err != nil {
return nil, false, err
}
@@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
versionStr := strconv.FormatInt(version, 10)
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
pipe := c.rdb.Pipeline()
for _, account := range accounts {
payload, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
if err := c.writeAccounts(ctx, accounts); err != nil {
return err
}
pipe := c.rdb.Pipeline()
if len(accounts) > 0 {
// 使用序号作为 score保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
@@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10),
})
}
pipe.ZAdd(ctx, snapshotKey, members...)
for start := 0; start < len(members); start += c.writeChunkSize {
end := start + c.writeChunkSize
if end > len(members) {
end = len(members)
}
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
}
} else {
pipe.Del(ctx, snapshotKey)
}
@@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun
if account == nil || account.ID <= 0 {
return nil
}
payload, err := json.Marshal(account)
if err != nil {
return err
}
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
return c.rdb.Set(ctx, key, payload, 0).Err()
return c.writeAccounts(ctx, []service.Account{*account})
}
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
return c.rdb.Del(ctx, key).Err()
id := strconv.FormatInt(accountID, 10)
return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err()
}
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
@@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
ids = append(ids, id)
}
values, err := c.rdb.MGet(ctx, keys...).Result()
values, err := c.mgetChunked(ctx, keys)
if err != nil {
return err
}
@@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
if err != nil {
return err
}
metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account))
if err != nil {
return err
}
pipe.Set(ctx, keys[i], updated, 0)
pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0)
}
_, err = pipe.Exec(ctx)
return err
@@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string {
return schedulerAccountPrefix + id
}
func schedulerAccountMetaKey(id string) string {
return schedulerAccountMetaPrefix + id
}
func ptrTime(t time.Time) *time.Time {
return &t
}
@@ -276,3 +303,137 @@ func decodeCachedAccount(val any) (*service.Account, error) {
}
return &account, nil
}
func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error {
if len(accounts) == 0 {
return nil
}
pipe := c.rdb.Pipeline()
pending := 0
flush := func() error {
if pending == 0 {
return nil
}
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = c.rdb.Pipeline()
pending = 0
return nil
}
for _, account := range accounts {
fullPayload, err := json.Marshal(account)
if err != nil {
return err
}
metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account))
if err != nil {
return err
}
id := strconv.FormatInt(account.ID, 10)
pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0)
pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0)
pending++
if pending >= c.writeChunkSize {
if err := flush(); err != nil {
return err
}
}
}
return flush()
}
func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) {
if len(keys) == 0 {
return []any{}, nil
}
out := make([]any, 0, len(keys))
chunkSize := c.mgetChunkSize
if chunkSize <= 0 {
chunkSize = defaultSchedulerSnapshotMGetChunkSize
}
for start := 0; start < len(keys); start += chunkSize {
end := start + chunkSize
if end > len(keys) {
end = len(keys)
}
part, err := c.rdb.MGet(ctx, keys[start:end]...).Result()
if err != nil {
return nil, err
}
out = append(out, part...)
}
return out, nil
}
func buildSchedulerMetadataAccount(account service.Account) service.Account {
return service.Account{
ID: account.ID,
Name: account.Name,
Platform: account.Platform,
Type: account.Type,
Concurrency: account.Concurrency,
Priority: account.Priority,
RateMultiplier: account.RateMultiplier,
Status: account.Status,
LastUsedAt: account.LastUsedAt,
ExpiresAt: account.ExpiresAt,
AutoPauseOnExpired: account.AutoPauseOnExpired,
Schedulable: account.Schedulable,
RateLimitedAt: account.RateLimitedAt,
RateLimitResetAt: account.RateLimitResetAt,
OverloadUntil: account.OverloadUntil,
TempUnschedulableUntil: account.TempUnschedulableUntil,
TempUnschedulableReason: account.TempUnschedulableReason,
SessionWindowStart: account.SessionWindowStart,
SessionWindowEnd: account.SessionWindowEnd,
SessionWindowStatus: account.SessionWindowStatus,
Credentials: filterSchedulerCredentials(account.Credentials),
Extra: filterSchedulerExtra(account.Extra),
}
}
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
if len(credentials) == 0 {
return nil
}
keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"}
filtered := make(map[string]any)
for _, key := range keys {
if value, ok := credentials[key]; ok && value != nil {
filtered[key] = value
}
}
if len(filtered) == 0 {
return nil
}
return filtered
}
func filterSchedulerExtra(extra map[string]any) map[string]any {
if len(extra) == 0 {
return nil
}
keys := []string{
"mixed_scheduling",
"window_cost_limit",
"window_cost_sticky_reserve",
"max_sessions",
"session_idle_timeout_minutes",
}
filtered := make(map[string]any)
for _, key := range keys {
if value, ok := extra[key]; ok && value != nil {
filtered[key] = value
}
}
if len(filtered) == 0 {
return nil
}
return filtered
}

View File

@@ -0,0 +1,88 @@
//go:build integration
package repository
import (
"context"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) {
ctx := context.Background()
rdb := testRedis(t)
cache := NewSchedulerCache(rdb)
bucket := service.SchedulerBucket{GroupID: 2, Platform: service.PlatformGemini, Mode: service.SchedulerModeSingle}
now := time.Now().UTC().Truncate(time.Second)
limitReset := now.Add(10 * time.Minute)
overloadUntil := now.Add(2 * time.Minute)
tempUnschedUntil := now.Add(3 * time.Minute)
windowEnd := now.Add(5 * time.Hour)
account := service.Account{
ID: 101,
Name: "gemini-heavy",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 7,
LastUsedAt: &now,
Credentials: map[string]any{
"api_key": "gemini-api-key",
"access_token": "secret-access-token",
"project_id": "proj-1",
"oauth_type": "ai_studio",
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
"huge_blob": strings.Repeat("x", 4096),
},
Extra: map[string]any{
"mixed_scheduling": true,
"window_cost_limit": 12.5,
"window_cost_sticky_reserve": 8.0,
"max_sessions": 4,
"session_idle_timeout_minutes": 11,
"unused_large_field": strings.Repeat("y", 4096),
},
RateLimitResetAt: &limitReset,
OverloadUntil: &overloadUntil,
TempUnschedulableUntil: &tempUnschedUntil,
SessionWindowStart: &now,
SessionWindowEnd: &windowEnd,
SessionWindowStatus: "active",
}
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
snapshot, hit, err := cache.GetSnapshot(ctx, bucket)
require.NoError(t, err)
require.True(t, hit)
require.Len(t, snapshot, 1)
got := snapshot[0]
require.NotNil(t, got)
require.Equal(t, "gemini-api-key", got.GetCredential("api_key"))
require.Equal(t, "proj-1", got.GetCredential("project_id"))
require.Equal(t, "ai_studio", got.GetCredential("oauth_type"))
require.NotEmpty(t, got.GetModelMapping())
require.Empty(t, got.GetCredential("access_token"))
require.Empty(t, got.GetCredential("huge_blob"))
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Equal(t, 12.5, got.GetWindowCostLimit())
require.Equal(t, 8.0, got.GetWindowCostStickyReserve())
require.Equal(t, 4, got.GetMaxSessions())
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
require.Nil(t, got.Extra["unused_large_field"])
full, err := cache.GetAccount(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, full)
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
}

View File

@@ -47,6 +47,21 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProvideSchedulerCache 创建调度快照缓存,并注入快照分块参数。
func ProvideSchedulerCache(rdb *redis.Client, cfg *config.Config) service.SchedulerCache {
mgetChunkSize := defaultSchedulerSnapshotMGetChunkSize
writeChunkSize := defaultSchedulerSnapshotWriteChunkSize
if cfg != nil {
if cfg.Gateway.Scheduling.SnapshotMGetChunkSize > 0 {
mgetChunkSize = cfg.Gateway.Scheduling.SnapshotMGetChunkSize
}
if cfg.Gateway.Scheduling.SnapshotWriteChunkSize > 0 {
writeChunkSize = cfg.Gateway.Scheduling.SnapshotWriteChunkSize
}
}
return newSchedulerCacheWithChunkSizes(rdb, mgetChunkSize, writeChunkSize)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -92,7 +107,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCache,
NewUpdateCache,
NewGeminiTokenCache,
NewSchedulerCache,
ProvideSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,