perf(gateway): 优化热点路径并补齐高覆盖测试

This commit is contained in:
yangjianbo
2026-02-22 13:31:30 +08:00
parent 2f520c8d47
commit a89477ddf5
16 changed files with 1760 additions and 76 deletions

View File

@@ -24,12 +24,15 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
gocache "github.com/patrickmn/go-cache"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/sync/singleflight"
"github.com/gin-gonic/gin"
)
@@ -44,6 +47,9 @@ const (
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
defaultUserGroupRateCacheTTL = 30 * time.Second
defaultModelsListCacheTTL = 15 * time.Second
)
const (
@@ -62,6 +68,53 @@ type accountWithLoad struct {
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
var (
windowCostPrefetchCacheHitTotal atomic.Int64
windowCostPrefetchCacheMissTotal atomic.Int64
windowCostPrefetchBatchSQLTotal atomic.Int64
windowCostPrefetchFallbackTotal atomic.Int64
windowCostPrefetchErrorTotal atomic.Int64
userGroupRateCacheHitTotal atomic.Int64
userGroupRateCacheMissTotal atomic.Int64
userGroupRateCacheLoadTotal atomic.Int64
userGroupRateCacheSFSharedTotal atomic.Int64
userGroupRateCacheFallbackTotal atomic.Int64
modelsListCacheHitTotal atomic.Int64
modelsListCacheMissTotal atomic.Int64
modelsListCacheStoreTotal atomic.Int64
)
func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) {
return windowCostPrefetchCacheHitTotal.Load(),
windowCostPrefetchCacheMissTotal.Load(),
windowCostPrefetchBatchSQLTotal.Load(),
windowCostPrefetchFallbackTotal.Load(),
windowCostPrefetchErrorTotal.Load()
}
func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) {
return userGroupRateCacheHitTotal.Load(),
userGroupRateCacheMissTotal.Load(),
userGroupRateCacheLoadTotal.Load(),
userGroupRateCacheSFSharedTotal.Load(),
userGroupRateCacheFallbackTotal.Load()
}
func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
}
func cloneStringSlice(src []string) []string {
if len(src) == 0 {
return nil
}
dst := make([]string, len(src))
copy(dst, src)
return dst
}
// IsForceCacheBilling 检查是否启用强制缓存计费
func IsForceCacheBilling(ctx context.Context) bool {
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
@@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration {
if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
return defaultUserGroupRateCacheTTL
}
return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second
}
func resolveModelsListCacheTTL(cfg *config.Config) time.Duration {
if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 {
return defaultModelsListCacheTTL
}
return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second
}
func modelsListCacheKey(groupID *int64, platform string) string {
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
}
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 {
if ctx == nil {
return 0
}
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
switch t := v.(type) {
case int64:
if t > 0 {
return t
}
case int:
if t > 0 {
return int64(t)
}
}
return 0
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或请求的模型处于限流状态时,返回 true。
@@ -421,6 +510,10 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
modelsListCacheTTL time.Duration
}
// NewGatewayService creates a new GatewayService
@@ -445,6 +538,9 @@ func NewGatewayService(
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
@@ -465,6 +561,9 @@ func NewGatewayService(
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
}
}
@@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
stickyAccountID = prefetch
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
@@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
@@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(routingCandidates) > 0 {
// 1.5. 在路由账号范围内检查粘性会话
if sessionHash != "" && s.cache != nil {
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
if sessionHash != "" && stickyAccountID > 0 {
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
@@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
accountID := stickyAccountID
if accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
@@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
type usageLogWindowStatsBatchProvider interface {
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
type windowCostPrefetchContextKeyType struct{}
var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{}
func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) {
if ctx == nil || accountID <= 0 {
return 0, false
}
m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64)
if !ok || len(m) == 0 {
return 0, false
}
v, exists := m[accountID]
return v, exists
}
func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context {
if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil {
return ctx
}
accountByID := make(map[int64]*Account)
accountIDs := make([]int64, 0, len(accounts))
for i := range accounts {
account := &accounts[i]
if account == nil || !account.IsAnthropicOAuthOrSetupToken() {
continue
}
if account.GetWindowCostLimit() <= 0 {
continue
}
accountByID[account.ID] = account
accountIDs = append(accountIDs, account.ID)
}
if len(accountIDs) == 0 {
return ctx
}
costs := make(map[int64]float64, len(accountIDs))
cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs)
if err == nil {
for accountID, cost := range cacheValues {
costs[accountID] = cost
}
windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues)))
} else {
windowCostPrefetchErrorTotal.Add(1)
logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err)
}
cacheMissCount := len(accountIDs) - len(costs)
if cacheMissCount < 0 {
cacheMissCount = 0
}
windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount))
missingByStart := make(map[int64][]int64)
startTimes := make(map[int64]time.Time)
for _, accountID := range accountIDs {
if _, ok := costs[accountID]; ok {
continue
}
account := accountByID[accountID]
if account == nil {
continue
}
startTime := account.GetCurrentWindowStartTime()
startKey := startTime.Unix()
missingByStart[startKey] = append(missingByStart[startKey], accountID)
startTimes[startKey] = startTime
}
if len(missingByStart) == 0 {
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider)
for startKey, ids := range missingByStart {
startTime := startTimes[startKey]
if hasBatch {
windowCostPrefetchBatchSQLTotal.Add(1)
queryStart := time.Now()
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime)
if err == nil {
slog.Debug("window_cost_batch_query_ok",
"accounts", len(ids),
"window_start", startTime.Format(time.RFC3339),
"duration_ms", time.Since(queryStart).Milliseconds())
for _, accountID := range ids {
stats := statsByAccount[accountID]
cost := 0.0
if stats != nil {
cost = stats.StandardCost
}
costs[accountID] = cost
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
}
continue
}
windowCostPrefetchErrorTotal.Add(1)
logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err)
}
// 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。
windowCostPrefetchFallbackTotal.Add(int64(len(ids)))
for _, accountID := range ids {
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
if err != nil {
windowCostPrefetchErrorTotal.Add(1)
continue
}
cost := stats.StandardCost
costs[accountID] = cost
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
}
}
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度false 表示不可调度
@@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 尝试从缓存获取窗口费用
var currentCost float64
if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok {
currentCost = cost
goto checkSchedulability
}
if s.sessionLimitCache != nil {
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
currentCost = cost
@@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return body
}
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
if s == nil || userID <= 0 || groupID <= 0 {
return groupDefaultMultiplier
}
key := fmt.Sprintf("%d:%d", userID, groupID)
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier
}
}
}
if s.userGroupRateRepo == nil {
return groupDefaultMultiplier
}
userGroupRateCacheMissTotal.Add(1)
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
if s.userGroupRateCache != nil {
if cached, ok := s.userGroupRateCache.Get(key); ok {
if multiplier, castOK := cached.(float64); castOK {
userGroupRateCacheHitTotal.Add(1)
return multiplier, nil
}
}
}
userGroupRateCacheLoadTotal.Add(1)
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
if repoErr != nil {
return nil, repoErr
}
multiplier := groupDefaultMultiplier
if userRate != nil {
multiplier = *userRate
}
if s.userGroupRateCache != nil {
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
}
return multiplier, nil
})
if shared {
userGroupRateCacheSFSharedTotal.Add(1)
}
if err != nil {
userGroupRateCacheFallbackTotal.Add(1)
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
return groupDefaultMultiplier
}
multiplier, ok := value.(float64)
if !ok {
userGroupRateCacheFallbackTotal.Add(1)
return groupDefaultMultiplier
}
return multiplier
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
@@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
@@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
@@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
cacheKey := modelsListCacheKey(groupID, platform)
if s.modelsListCache != nil {
if cached, found := s.modelsListCache.Get(cacheKey); found {
if models, ok := cached.([]string); ok {
modelsListCacheHitTotal.Add(1)
return cloneStringSlice(models)
}
}
}
modelsListCacheMissTotal.Add(1)
var accounts []Account
var err error
@@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
if s.modelsListCache != nil {
s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL)
modelsListCacheStoreTotal.Add(1)
}
return nil
}
@@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
for model := range modelSet {
models = append(models, model)
}
sort.Strings(models)
return models
if s.modelsListCache != nil {
s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL)
modelsListCacheStoreTotal.Add(1)
}
return cloneStringSlice(models)
}
func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) {
if s == nil || s.modelsListCache == nil {
return
}
normalizedPlatform := strings.TrimSpace(platform)
// 完整匹配时精准失效;否则按维度批量失效。
if groupID != nil && normalizedPlatform != "" {
s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform))
return
}
targetGroup := derefGroupID(groupID)
for key := range s.modelsListCache.Items() {
parts := strings.SplitN(key, "|", 2)
if len(parts) != 2 {
continue
}
groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64)
if parseErr != nil {
continue
}
if groupID != nil && groupPart != targetGroup {
continue
}
if normalizedPlatform != "" && parts[1] != normalizedPlatform {
continue
}
s.modelsListCache.Delete(key)
}
}
// reconcileCachedTokens 兼容 Kimi 等上游: