perf(gateway): 优化热点路径并补齐高覆盖测试
This commit is contained in:
755
backend/internal/service/gateway_hotpath_optimization_test.go
Normal file
755
backend/internal/service/gateway_hotpath_optimization_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userGroupRateRepoHotpathStub struct {
|
||||
UserGroupRateRepository
|
||||
|
||||
rate *float64
|
||||
err error
|
||||
wait <-chan struct{}
|
||||
calls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
s.calls.Add(1)
|
||||
if s.wait != nil {
|
||||
<-s.wait
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.rate, nil
|
||||
}
|
||||
|
||||
type usageLogWindowBatchRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
batchResult map[int64]*usagestats.AccountStats
|
||||
batchErr error
|
||||
batchCalls atomic.Int64
|
||||
|
||||
singleResult map[int64]*usagestats.AccountStats
|
||||
singleErr error
|
||||
singleCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
|
||||
s.batchCalls.Add(1)
|
||||
if s.batchErr != nil {
|
||||
return nil, s.batchErr
|
||||
}
|
||||
out := make(map[int64]*usagestats.AccountStats, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
if stats, ok := s.batchResult[id]; ok {
|
||||
out[id] = stats
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
s.singleCalls.Add(1)
|
||||
if s.singleErr != nil {
|
||||
return nil, s.singleErr
|
||||
}
|
||||
if stats, ok := s.singleResult[accountID]; ok {
|
||||
return stats, nil
|
||||
}
|
||||
return &usagestats.AccountStats{}, nil
|
||||
}
|
||||
|
||||
type sessionLimitCacheHotpathStub struct {
|
||||
SessionLimitCache
|
||||
|
||||
batchData map[int64]float64
|
||||
batchErr error
|
||||
|
||||
setData map[int64]float64
|
||||
setErr error
|
||||
}
|
||||
|
||||
func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
|
||||
if s.batchErr != nil {
|
||||
return nil, s.batchErr
|
||||
}
|
||||
out := make(map[int64]float64, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
if v, ok := s.batchData[id]; ok {
|
||||
out[id] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
if s.setData == nil {
|
||||
s.setData = make(map[int64]float64)
|
||||
}
|
||||
s.setData[accountID] = cost
|
||||
return nil
|
||||
}
|
||||
|
||||
type modelsListAccountRepoStub struct {
|
||||
AccountRepository
|
||||
|
||||
byGroup map[int64][]Account
|
||||
all []Account
|
||||
err error
|
||||
|
||||
listByGroupCalls atomic.Int64
|
||||
listAllCalls atomic.Int64
|
||||
}
|
||||
|
||||
type stickyGatewayCacheHotpathStub struct {
|
||||
GatewayCache
|
||||
|
||||
stickyID int64
|
||||
getCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
s.getCalls.Add(1)
|
||||
if s.stickyID > 0 {
|
||||
return s.stickyID, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
s.listByGroupCalls.Add(1)
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
accounts, ok := s.byGroup[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]Account, len(accounts))
|
||||
copy(out, accounts)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
s.listAllCalls.Add(1)
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
out := make([]Account, len(s.all))
|
||||
copy(out, s.all)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func resetGatewayHotpathStatsForTest() {
|
||||
windowCostPrefetchCacheHitTotal.Store(0)
|
||||
windowCostPrefetchCacheMissTotal.Store(0)
|
||||
windowCostPrefetchBatchSQLTotal.Store(0)
|
||||
windowCostPrefetchFallbackTotal.Store(0)
|
||||
windowCostPrefetchErrorTotal.Store(0)
|
||||
|
||||
userGroupRateCacheHitTotal.Store(0)
|
||||
userGroupRateCacheMissTotal.Store(0)
|
||||
userGroupRateCacheLoadTotal.Store(0)
|
||||
userGroupRateCacheSFSharedTotal.Store(0)
|
||||
userGroupRateCacheFallbackTotal.Store(0)
|
||||
|
||||
modelsListCacheHitTotal.Store(0)
|
||||
modelsListCacheMissTotal.Store(0)
|
||||
modelsListCacheStoreTotal.Store(0)
|
||||
}
|
||||
|
||||
func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
rate := 1.7
|
||||
unblock := make(chan struct{})
|
||||
repo := &userGroupRateRepoHotpathStub{
|
||||
rate: &rate,
|
||||
wait: unblock,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
userGroupRateRepo: repo,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
UserGroupRateCacheTTLSeconds: 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const concurrent = 12
|
||||
results := make([]float64, concurrent)
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrent)
|
||||
for i := 0; i < concurrent; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(start)
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
close(unblock)
|
||||
wg.Wait()
|
||||
|
||||
for _, got := range results {
|
||||
require.Equal(t, rate, got)
|
||||
}
|
||||
require.Equal(t, int64(1), repo.calls.Load())
|
||||
|
||||
// 再次读取应命中缓存,不再回源。
|
||||
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||
require.Equal(t, rate, got)
|
||||
require.Equal(t, int64(1), repo.calls.Load())
|
||||
|
||||
hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats()
|
||||
require.GreaterOrEqual(t, hit, int64(1))
|
||||
require.Equal(t, int64(12), miss)
|
||||
require.Equal(t, int64(1), load)
|
||||
require.GreaterOrEqual(t, sfShared, int64(1))
|
||||
require.Equal(t, int64(0), fallback)
|
||||
}
|
||||
|
||||
func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
repo := &userGroupRateRepoHotpathStub{
|
||||
err: errors.New("db down"),
|
||||
}
|
||||
svc := &GatewayService{
|
||||
userGroupRateRepo: repo,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
UserGroupRateCacheTTLSeconds: 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25)
|
||||
require.Equal(t, 1.25, got)
|
||||
require.Equal(t, int64(1), repo.calls.Load())
|
||||
|
||||
_, _, _, _, fallback := GatewayUserGroupRateCacheStats()
|
||||
require.Equal(t, int64(1), fallback)
|
||||
}
|
||||
|
||||
func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
repo := &userGroupRateRepoHotpathStub{
|
||||
err: errors.New("should not be called"),
|
||||
}
|
||||
svc := &GatewayService{
|
||||
userGroupRateRepo: repo,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
}
|
||||
key := "101:202"
|
||||
svc.userGroupRateCache.Set(key, 2.3, time.Minute)
|
||||
|
||||
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1)
|
||||
require.Equal(t, 2.3, got)
|
||||
|
||||
hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
|
||||
require.Equal(t, int64(1), hit)
|
||||
require.Equal(t, int64(0), miss)
|
||||
require.Equal(t, int64(0), load)
|
||||
require.Equal(t, int64(0), fallback)
|
||||
require.Equal(t, int64(0), repo.calls.Load())
|
||||
|
||||
// 无 repo 时直接返回分组默认倍率
|
||||
svc2 := &GatewayService{
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
}
|
||||
svc2.userGroupRateCache.Set(key, 1.9, time.Minute)
|
||||
require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
|
||||
require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4))
|
||||
svc2.userGroupRateCache.Delete(key)
|
||||
require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
|
||||
}
|
||||
|
||||
func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
|
||||
windowEnd := windowStart.Add(5 * time.Hour)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
SessionWindowStart: &windowStart,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
SessionWindowStart: &windowStart,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &sessionLimitCacheHotpathStub{
|
||||
batchData: map[int64]float64{
|
||||
1: 11.0,
|
||||
},
|
||||
}
|
||||
repo := &usageLogWindowBatchRepoStub{
|
||||
batchResult: map[int64]*usagestats.AccountStats{
|
||||
2: {StandardCost: 22.0},
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
sessionLimitCache: cache,
|
||||
usageLogRepo: repo,
|
||||
}
|
||||
|
||||
outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
|
||||
require.NotNil(t, outCtx)
|
||||
|
||||
cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
|
||||
require.True(t, ok1)
|
||||
require.Equal(t, 11.0, cost1)
|
||||
|
||||
cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
|
||||
require.True(t, ok2)
|
||||
require.Equal(t, 22.0, cost2)
|
||||
|
||||
_, ok3 := windowCostFromPrefetchContext(outCtx, 3)
|
||||
require.False(t, ok3)
|
||||
|
||||
require.Equal(t, int64(1), repo.batchCalls.Load())
|
||||
require.Equal(t, 22.0, cache.setData[2])
|
||||
|
||||
hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
|
||||
require.Equal(t, int64(1), hit)
|
||||
require.Equal(t, int64(1), miss)
|
||||
require.Equal(t, int64(1), batchSQL)
|
||||
require.Equal(t, int64(0), fallback)
|
||||
require.Equal(t, int64(0), errCount)
|
||||
}
|
||||
|
||||
func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
|
||||
windowEnd := windowStart.Add(5 * time.Hour)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
SessionWindowStart: &windowStart,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
SessionWindowStart: &windowStart,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
},
|
||||
}
|
||||
|
||||
cache := &sessionLimitCacheHotpathStub{
|
||||
batchData: map[int64]float64{
|
||||
1: 11.0,
|
||||
2: 22.0,
|
||||
},
|
||||
}
|
||||
repo := &usageLogWindowBatchRepoStub{}
|
||||
svc := &GatewayService{
|
||||
sessionLimitCache: cache,
|
||||
usageLogRepo: repo,
|
||||
}
|
||||
|
||||
outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
|
||||
cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
|
||||
cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
|
||||
require.True(t, ok1)
|
||||
require.True(t, ok2)
|
||||
require.Equal(t, 11.0, cost1)
|
||||
require.Equal(t, 22.0, cost2)
|
||||
require.Equal(t, int64(0), repo.batchCalls.Load())
|
||||
require.Equal(t, int64(0), repo.singleCalls.Load())
|
||||
|
||||
hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
|
||||
require.Equal(t, int64(2), hit)
|
||||
require.Equal(t, int64(0), miss)
|
||||
require.Equal(t, int64(0), batchSQL)
|
||||
require.Equal(t, int64(0), fallback)
|
||||
require.Equal(t, int64(0), errCount)
|
||||
}
|
||||
|
||||
func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
|
||||
windowEnd := windowStart.Add(5 * time.Hour)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
Extra: map[string]any{"window_cost_limit": 100.0},
|
||||
SessionWindowStart: &windowStart,
|
||||
SessionWindowEnd: &windowEnd,
|
||||
},
|
||||
}
|
||||
|
||||
cache := &sessionLimitCacheHotpathStub{}
|
||||
repo := &usageLogWindowBatchRepoStub{
|
||||
batchErr: errors.New("batch failed"),
|
||||
singleResult: map[int64]*usagestats.AccountStats{
|
||||
2: {StandardCost: 33.0},
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
sessionLimitCache: cache,
|
||||
usageLogRepo: repo,
|
||||
}
|
||||
|
||||
outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
|
||||
cost, ok := windowCostFromPrefetchContext(outCtx, 2)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 33.0, cost)
|
||||
require.Equal(t, int64(1), repo.batchCalls.Load())
|
||||
require.Equal(t, int64(1), repo.singleCalls.Load())
|
||||
|
||||
_, _, _, fallback, errCount := GatewayWindowCostPrefetchStats()
|
||||
require.Equal(t, int64(1), fallback)
|
||||
require.Equal(t, int64(1), errCount)
|
||||
}
|
||||
|
||||
func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
groupID := int64(9)
|
||||
repo := &modelsListAccountRepoStub{
|
||||
byGroup: map[int64][]Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "claude-3-5-sonnet",
|
||||
"claude-3-5-haiku": "claude-3-5-haiku",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
|
||||
models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
|
||||
require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1)
|
||||
require.Equal(t, int64(1), repo.listByGroupCalls.Load())
|
||||
|
||||
// TTL 内再次请求应命中缓存,不回源。
|
||||
models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
|
||||
require.Equal(t, models1, models2)
|
||||
require.Equal(t, int64(1), repo.listByGroupCalls.Load())
|
||||
|
||||
// 更新仓储数据,但缓存未失效前应继续返回旧值。
|
||||
repo.byGroup[groupID] = []Account{
|
||||
{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-7-sonnet": "claude-3-7-sonnet",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
|
||||
require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3)
|
||||
require.Equal(t, int64(1), repo.listByGroupCalls.Load())
|
||||
|
||||
svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic)
|
||||
models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
|
||||
require.Equal(t, []string{"claude-3-7-sonnet"}, models4)
|
||||
require.Equal(t, int64(2), repo.listByGroupCalls.Load())
|
||||
|
||||
hit, miss, store := GatewayModelsListCacheStats()
|
||||
require.Equal(t, int64(2), hit)
|
||||
require.Equal(t, int64(2), miss)
|
||||
require.Equal(t, int64(2), store)
|
||||
}
|
||||
|
||||
func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) {
|
||||
resetGatewayHotpathStatsForTest()
|
||||
|
||||
errRepo := &modelsListAccountRepoStub{
|
||||
err: errors.New("db error"),
|
||||
}
|
||||
svcErr := &GatewayService{
|
||||
accountRepo: errRepo,
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, ""))
|
||||
|
||||
okRepo := &modelsListAccountRepoStub{
|
||||
all: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "claude-3-5-sonnet",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svcOK := &GatewayService{
|
||||
accountRepo: okRepo,
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
models := svcOK.GetAvailableModels(context.Background(), nil, "")
|
||||
require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models)
|
||||
require.Equal(t, int64(1), okRepo.listAllCalls.Load())
|
||||
}
|
||||
|
||||
func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
|
||||
t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) {
|
||||
require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil))
|
||||
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
UserGroupRateCacheTTLSeconds: 45,
|
||||
},
|
||||
}
|
||||
require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg))
|
||||
})
|
||||
|
||||
t.Run("resolve_models_list_cache_ttl", func(t *testing.T) {
|
||||
require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil))
|
||||
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
ModelsListCacheTTLSeconds: 20,
|
||||
},
|
||||
}
|
||||
require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg))
|
||||
})
|
||||
|
||||
t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO()))
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background()))
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
|
||||
require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx))
|
||||
|
||||
ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
|
||||
require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2))
|
||||
|
||||
ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
|
||||
require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3))
|
||||
})
|
||||
|
||||
t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
|
||||
require.Equal(t, false, func() bool {
|
||||
_, ok := windowCostFromPrefetchContext(context.TODO(), 0)
|
||||
return ok
|
||||
}())
|
||||
require.Equal(t, false, func() bool {
|
||||
_, ok := windowCostFromPrefetchContext(context.Background(), 1)
|
||||
return ok
|
||||
}())
|
||||
|
||||
ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{
|
||||
9: 12.34,
|
||||
})
|
||||
cost, ok := windowCostFromPrefetchContext(ctx, 9)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 12.34, cost)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
}
|
||||
group9 := int64(9)
|
||||
group10 := int64(10)
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)
|
||||
svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute)
|
||||
|
||||
t.Run("invalidate_group_and_platform", func(t *testing.T) {
|
||||
svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic)
|
||||
_, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
|
||||
require.False(t, found)
|
||||
_, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
|
||||
require.True(t, stillFound)
|
||||
})
|
||||
|
||||
t.Run("invalidate_group_only", func(t *testing.T) {
|
||||
svc.InvalidateAvailableModelsCache(&group9, "")
|
||||
_, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
|
||||
_, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
|
||||
require.False(t, foundA)
|
||||
require.False(t, foundB)
|
||||
_, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
|
||||
require.True(t, foundOtherGroup)
|
||||
})
|
||||
|
||||
t.Run("invalidate_platform_only", func(t *testing.T) {
|
||||
// 重建数据后仅按 platform 失效
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
|
||||
svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)
|
||||
|
||||
svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic)
|
||||
_, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
|
||||
_, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
|
||||
_, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
|
||||
require.False(t, found9Anthropic)
|
||||
require.False(t, found10Anthropic)
|
||||
require.True(t, found9Gemini)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||
now := time.Now().Add(-time.Minute)
|
||||
account := Account{
|
||||
ID: 88,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 4,
|
||||
Priority: 1,
|
||||
LastUsedAt: &now,
|
||||
}
|
||||
|
||||
repo := stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||
concurrency := NewConcurrencyService(stubConcurrencyCache{})
|
||||
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeStandard,
|
||||
Gateway: config.GatewayConfig{
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
LoadBatchEnabled: true,
|
||||
StickySessionMaxWaiting: 3,
|
||||
StickySessionWaitTimeout: time.Second,
|
||||
FallbackWaitTimeout: time.Second,
|
||||
FallbackMaxWaiting: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic)
|
||||
|
||||
t.Run("without_prefetch_reads_cache_once", func(t *testing.T) {
|
||||
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrency,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.Equal(t, int64(1), cache.getCalls.Load())
|
||||
})
|
||||
|
||||
t.Run("with_prefetch_skips_cache_read", func(t *testing.T) {
|
||||
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrency,
|
||||
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||
modelsListCacheTTL: time.Minute,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, account.ID, result.Account.ID)
|
||||
require.Equal(t, int64(0), cache.getCalls.Load())
|
||||
})
|
||||
}
|
||||
@@ -24,12 +24,15 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/google/uuid"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -44,6 +47,9 @@ const (
|
||||
// separator between system blocks, we add "\n\n" at concatenation time.
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
|
||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -62,6 +68,53 @@ type accountWithLoad struct {
|
||||
|
||||
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
||||
|
||||
var (
|
||||
windowCostPrefetchCacheHitTotal atomic.Int64
|
||||
windowCostPrefetchCacheMissTotal atomic.Int64
|
||||
windowCostPrefetchBatchSQLTotal atomic.Int64
|
||||
windowCostPrefetchFallbackTotal atomic.Int64
|
||||
windowCostPrefetchErrorTotal atomic.Int64
|
||||
|
||||
userGroupRateCacheHitTotal atomic.Int64
|
||||
userGroupRateCacheMissTotal atomic.Int64
|
||||
userGroupRateCacheLoadTotal atomic.Int64
|
||||
userGroupRateCacheSFSharedTotal atomic.Int64
|
||||
userGroupRateCacheFallbackTotal atomic.Int64
|
||||
|
||||
modelsListCacheHitTotal atomic.Int64
|
||||
modelsListCacheMissTotal atomic.Int64
|
||||
modelsListCacheStoreTotal atomic.Int64
|
||||
)
|
||||
|
||||
func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) {
|
||||
return windowCostPrefetchCacheHitTotal.Load(),
|
||||
windowCostPrefetchCacheMissTotal.Load(),
|
||||
windowCostPrefetchBatchSQLTotal.Load(),
|
||||
windowCostPrefetchFallbackTotal.Load(),
|
||||
windowCostPrefetchErrorTotal.Load()
|
||||
}
|
||||
|
||||
func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) {
|
||||
return userGroupRateCacheHitTotal.Load(),
|
||||
userGroupRateCacheMissTotal.Load(),
|
||||
userGroupRateCacheLoadTotal.Load(),
|
||||
userGroupRateCacheSFSharedTotal.Load(),
|
||||
userGroupRateCacheFallbackTotal.Load()
|
||||
}
|
||||
|
||||
func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||
}
|
||||
|
||||
func cloneStringSlice(src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make([]string, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsForceCacheBilling 检查是否启用强制缓存计费
|
||||
func IsForceCacheBilling(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
|
||||
@@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration {
|
||||
if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
|
||||
return defaultUserGroupRateCacheTTL
|
||||
}
|
||||
return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second
|
||||
}
|
||||
|
||||
func resolveModelsListCacheTTL(cfg *config.Config) time.Duration {
|
||||
if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 {
|
||||
return defaultModelsListCacheTTL
|
||||
}
|
||||
return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second
|
||||
}
|
||||
|
||||
func modelsListCacheKey(groupID *int64, platform string) string {
|
||||
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
|
||||
}
|
||||
|
||||
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
if t > 0 {
|
||||
return t
|
||||
}
|
||||
case int:
|
||||
if t > 0 {
|
||||
return int64(t)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
||||
// 或请求的模型处于限流状态时,返回 true。
|
||||
@@ -421,6 +510,10 @@ type GatewayService struct {
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
modelsListCache *gocache.Cache
|
||||
modelsListCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -445,6 +538,9 @@ func NewGatewayService(
|
||||
sessionLimitCache SessionLimitCache,
|
||||
digestStore *DigestSessionStore,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -465,6 +561,9 @@ func NewGatewayService(
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
cfg := s.schedulingConfig()
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
|
||||
stickyAccountID = prefetch
|
||||
} else if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
@@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
@@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
if sessionHash != "" && stickyAccountID > 0 {
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
if stickyAccount.IsSchedulable() &&
|
||||
@@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
|
||||
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
|
||||
accountID := stickyAccountID
|
||||
if accountID > 0 && !isExcluded(accountID) {
|
||||
account, ok := accountByID[accountID]
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
@@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
type usageLogWindowStatsBatchProvider interface {
|
||||
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
|
||||
}
|
||||
|
||||
type windowCostPrefetchContextKeyType struct{}
|
||||
|
||||
var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{}
|
||||
|
||||
func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) {
|
||||
if ctx == nil || accountID <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64)
|
||||
if !ok || len(m) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
v, exists := m[accountID]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context {
|
||||
if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
accountByID := make(map[int64]*Account)
|
||||
accountIDs := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
if account == nil || !account.IsAnthropicOAuthOrSetupToken() {
|
||||
continue
|
||||
}
|
||||
if account.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
accountByID[account.ID] = account
|
||||
accountIDs = append(accountIDs, account.ID)
|
||||
}
|
||||
if len(accountIDs) == 0 {
|
||||
return ctx
|
||||
}
|
||||
|
||||
costs := make(map[int64]float64, len(accountIDs))
|
||||
cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs)
|
||||
if err == nil {
|
||||
for accountID, cost := range cacheValues {
|
||||
costs[accountID] = cost
|
||||
}
|
||||
windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues)))
|
||||
} else {
|
||||
windowCostPrefetchErrorTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err)
|
||||
}
|
||||
cacheMissCount := len(accountIDs) - len(costs)
|
||||
if cacheMissCount < 0 {
|
||||
cacheMissCount = 0
|
||||
}
|
||||
windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount))
|
||||
|
||||
missingByStart := make(map[int64][]int64)
|
||||
startTimes := make(map[int64]time.Time)
|
||||
for _, accountID := range accountIDs {
|
||||
if _, ok := costs[accountID]; ok {
|
||||
continue
|
||||
}
|
||||
account := accountByID[accountID]
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
startKey := startTime.Unix()
|
||||
missingByStart[startKey] = append(missingByStart[startKey], accountID)
|
||||
startTimes[startKey] = startTime
|
||||
}
|
||||
if len(missingByStart) == 0 {
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider)
|
||||
for startKey, ids := range missingByStart {
|
||||
startTime := startTimes[startKey]
|
||||
|
||||
if hasBatch {
|
||||
windowCostPrefetchBatchSQLTotal.Add(1)
|
||||
queryStart := time.Now()
|
||||
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime)
|
||||
if err == nil {
|
||||
slog.Debug("window_cost_batch_query_ok",
|
||||
"accounts", len(ids),
|
||||
"window_start", startTime.Format(time.RFC3339),
|
||||
"duration_ms", time.Since(queryStart).Milliseconds())
|
||||
for _, accountID := range ids {
|
||||
stats := statsByAccount[accountID]
|
||||
cost := 0.0
|
||||
if stats != nil {
|
||||
cost = stats.StandardCost
|
||||
}
|
||||
costs[accountID] = cost
|
||||
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
|
||||
}
|
||||
continue
|
||||
}
|
||||
windowCostPrefetchErrorTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err)
|
||||
}
|
||||
|
||||
// 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。
|
||||
windowCostPrefetchFallbackTotal.Add(int64(len(ids)))
|
||||
for _, accountID := range ids {
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
|
||||
if err != nil {
|
||||
windowCostPrefetchErrorTotal.Add(1)
|
||||
continue
|
||||
}
|
||||
cost := stats.StandardCost
|
||||
costs[accountID] = cost
|
||||
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
@@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
|
||||
|
||||
// 尝试从缓存获取窗口费用
|
||||
var currentCost float64
|
||||
if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok {
|
||||
currentCost = cost
|
||||
goto checkSchedulability
|
||||
}
|
||||
if s.sessionLimitCache != nil {
|
||||
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
|
||||
currentCost = cost
|
||||
@@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||
if s == nil || userID <= 0 || groupID <= 0 {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.userGroupRateRepo == nil {
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
userGroupRateCacheMissTotal.Add(1)
|
||||
|
||||
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
||||
if s.userGroupRateCache != nil {
|
||||
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||
if multiplier, castOK := cached.(float64); castOK {
|
||||
userGroupRateCacheHitTotal.Add(1)
|
||||
return multiplier, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userGroupRateCacheLoadTotal.Add(1)
|
||||
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
||||
if repoErr != nil {
|
||||
return nil, repoErr
|
||||
}
|
||||
multiplier := groupDefaultMultiplier
|
||||
if userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
if s.userGroupRateCache != nil {
|
||||
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
||||
}
|
||||
return multiplier, nil
|
||||
})
|
||||
if shared {
|
||||
userGroupRateCacheSFSharedTotal.Add(1)
|
||||
}
|
||||
if err != nil {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
|
||||
multiplier, ok := value.(float64)
|
||||
if !ok {
|
||||
userGroupRateCacheFallbackTotal.Add(1)
|
||||
return groupDefaultMultiplier
|
||||
}
|
||||
return multiplier
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
multiplier := 1.0
|
||||
if s.cfg != nil {
|
||||
multiplier = s.cfg.Default.RateMultiplier
|
||||
}
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
groupDefault := apiKey.Group.RateMultiplier
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
multiplier := 1.0
|
||||
if s.cfg != nil {
|
||||
multiplier = s.cfg.Default.RateMultiplier
|
||||
}
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
groupDefault := apiKey.Group.RateMultiplier
|
||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
cacheKey := modelsListCacheKey(groupID, platform)
|
||||
if s.modelsListCache != nil {
|
||||
if cached, found := s.modelsListCache.Get(cacheKey); found {
|
||||
if models, ok := cached.([]string); ok {
|
||||
modelsListCacheHitTotal.Add(1)
|
||||
return cloneStringSlice(models)
|
||||
}
|
||||
}
|
||||
}
|
||||
modelsListCacheMissTotal.Add(1)
|
||||
|
||||
var accounts []Account
|
||||
var err error
|
||||
|
||||
@@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
||||
|
||||
// If no account has model_mapping, return nil (use default)
|
||||
if !hasAnyMapping {
|
||||
if s.modelsListCache != nil {
|
||||
s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL)
|
||||
modelsListCacheStoreTotal.Add(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
||||
for model := range modelSet {
|
||||
models = append(models, model)
|
||||
}
|
||||
sort.Strings(models)
|
||||
|
||||
return models
|
||||
if s.modelsListCache != nil {
|
||||
s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL)
|
||||
modelsListCacheStoreTotal.Add(1)
|
||||
}
|
||||
return cloneStringSlice(models)
|
||||
}
|
||||
|
||||
func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) {
|
||||
if s == nil || s.modelsListCache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
normalizedPlatform := strings.TrimSpace(platform)
|
||||
// 完整匹配时精准失效;否则按维度批量失效。
|
||||
if groupID != nil && normalizedPlatform != "" {
|
||||
s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform))
|
||||
return
|
||||
}
|
||||
|
||||
targetGroup := derefGroupID(groupID)
|
||||
for key := range s.modelsListCache.Items() {
|
||||
parts := strings.SplitN(key, "|", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
if groupID != nil && groupPart != targetGroup {
|
||||
continue
|
||||
}
|
||||
if normalizedPlatform != "" && parts[1] != normalizedPlatform {
|
||||
continue
|
||||
}
|
||||
s.modelsListCache.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileCachedTokens 兼容 Kimi 等上游:
|
||||
|
||||
@@ -20,6 +20,22 @@ const (
|
||||
opsMaxStoredErrorBodyBytes = 20 * 1024
|
||||
)
|
||||
|
||||
// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。
|
||||
// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。
|
||||
func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) {
|
||||
if len(raw) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes)
|
||||
if sanitized != "" {
|
||||
out := sanitized
|
||||
requestBodyJSON = &out
|
||||
}
|
||||
n := bytesLen
|
||||
requestBodyBytes = &n
|
||||
return requestBodyJSON, truncated, requestBodyBytes
|
||||
}
|
||||
|
||||
// OpsService provides ingestion and query APIs for the Ops monitoring module.
|
||||
type OpsService struct {
|
||||
opsRepo OpsRepository
|
||||
@@ -132,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
|
||||
|
||||
// Sanitize + trim request body (errors only).
|
||||
if len(rawRequestBody) > 0 {
|
||||
sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes)
|
||||
if sanitized != "" {
|
||||
entry.RequestBodyJSON = &sanitized
|
||||
}
|
||||
entry.RequestBodyTruncated = truncated
|
||||
entry.RequestBodyBytes = &bytesLen
|
||||
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody)
|
||||
}
|
||||
|
||||
// Sanitize + truncate error_body to avoid storing sensitive data.
|
||||
|
||||
60
backend/internal/service/ops_service_prepare_queue_test.go
Normal file
60
backend/internal/service/ops_service_prepare_queue_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) {
|
||||
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil)
|
||||
require.Nil(t, requestBodyJSON)
|
||||
require.False(t, truncated)
|
||||
require.Nil(t, requestBodyBytes)
|
||||
}
|
||||
|
||||
func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) {
|
||||
raw := []byte("{invalid-json")
|
||||
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||
require.Nil(t, requestBodyJSON)
|
||||
require.False(t, truncated)
|
||||
require.NotNil(t, requestBodyBytes)
|
||||
require.Equal(t, len(raw), *requestBodyBytes)
|
||||
}
|
||||
|
||||
func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"api_key":"sk-test-123",
|
||||
"headers":{"authorization":"Bearer secret-token"},
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`)
|
||||
|
||||
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||
require.NotNil(t, requestBodyJSON)
|
||||
require.NotNil(t, requestBodyBytes)
|
||||
require.False(t, truncated)
|
||||
require.Equal(t, len(raw), *requestBodyBytes)
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body))
|
||||
require.Equal(t, "[REDACTED]", body["api_key"])
|
||||
headers, ok := body["headers"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "[REDACTED]", headers["authorization"])
|
||||
}
|
||||
|
||||
func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) {
|
||||
largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2)
|
||||
raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`)
|
||||
|
||||
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||
require.NotNil(t, requestBodyJSON)
|
||||
require.NotNil(t, requestBodyBytes)
|
||||
require.True(t, truncated)
|
||||
require.Equal(t, len(raw), *requestBodyBytes)
|
||||
require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes)
|
||||
require.Contains(t, *requestBodyJSON, "request_body_truncated")
|
||||
}
|
||||
Reference in New Issue
Block a user