Files
sub2api/backend/internal/service/openai_account_scheduler.go

1058 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"container/heap"
"context"
"errors"
"fmt"
"hash/fnv"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
const (
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
)
type cachedOpenAIAdvancedSchedulerSetting struct {
enabled bool
expiresAt int64
}
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
var openAIAdvancedSchedulerSettingSF singleflight.Group
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
StickyAccountID int64
PreviousResponseID string
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
ExcludedIDs map[int64]struct{}
}
type OpenAIAccountScheduleDecision struct {
Layer string
StickyPreviousHit bool
StickySessionHit bool
CandidateCount int
TopK int
LatencyMs int64
LoadSkew float64
SelectedAccountID int64
SelectedAccountType string
}
type OpenAIAccountSchedulerMetricsSnapshot struct {
SelectTotal int64
StickyPreviousHitTotal int64
StickySessionHitTotal int64
LoadBalanceSelectTotal int64
AccountSwitchTotal int64
SchedulerLatencyMsTotal int64
SchedulerLatencyMsAvg float64
StickyHitRatio float64
AccountSwitchRate float64
LoadSkewAvg float64
RuntimeStatsAccountCount int
}
type OpenAIAccountScheduler interface {
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
ReportResult(accountID int64, success bool, firstTokenMs *int)
ReportSwitch()
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
type openAIAccountSchedulerMetrics struct {
selectTotal atomic.Int64
stickyPreviousHitTotal atomic.Int64
stickySessionHitTotal atomic.Int64
loadBalanceSelectTotal atomic.Int64
accountSwitchTotal atomic.Int64
latencyMsTotal atomic.Int64
loadSkewMilliTotal atomic.Int64
}
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
if m == nil {
return
}
m.selectTotal.Add(1)
m.latencyMsTotal.Add(decision.LatencyMs)
m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000)))
if decision.StickyPreviousHit {
m.stickyPreviousHitTotal.Add(1)
}
if decision.StickySessionHit {
m.stickySessionHitTotal.Add(1)
}
if decision.Layer == openAIAccountScheduleLayerLoadBalance {
m.loadBalanceSelectTotal.Add(1)
}
}
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
if m == nil {
return
}
m.accountSwitchTotal.Add(1)
}
type openAIAccountRuntimeStats struct {
accounts sync.Map
accountCount atomic.Int64
}
type openAIAccountRuntimeStat struct {
errorRateEWMABits atomic.Uint64
ttftEWMABits atomic.Uint64
}
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
return &openAIAccountRuntimeStats{}
}
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
if value, ok := s.accounts.Load(accountID); ok {
stat, _ := value.(*openAIAccountRuntimeStat)
if stat != nil {
return stat
}
}
stat := &openAIAccountRuntimeStat{}
stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
if !loaded {
s.accountCount.Add(1)
return stat
}
existing, _ := actual.(*openAIAccountRuntimeStat)
if existing != nil {
return existing
}
return stat
}
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
for {
oldBits := target.Load()
oldValue := math.Float64frombits(oldBits)
newValue := alpha*sample + (1-alpha)*oldValue
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
return
}
}
}
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
if s == nil || accountID <= 0 {
return
}
const alpha = 0.2
stat := s.loadOrCreate(accountID)
errorSample := 1.0
if success {
errorSample = 0.0
}
updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
if firstTokenMs != nil && *firstTokenMs > 0 {
ttft := float64(*firstTokenMs)
ttftBits := math.Float64bits(ttft)
for {
oldBits := stat.ttftEWMABits.Load()
oldValue := math.Float64frombits(oldBits)
if math.IsNaN(oldValue) {
if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
break
}
continue
}
newValue := alpha*ttft + (1-alpha)*oldValue
if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
break
}
}
}
}
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
if s == nil || accountID <= 0 {
return 0, 0, false
}
value, ok := s.accounts.Load(accountID)
if !ok {
return 0, 0, false
}
stat, _ := value.(*openAIAccountRuntimeStat)
if stat == nil {
return 0, 0, false
}
errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
if math.IsNaN(ttftValue) {
return errorRate, 0, false
}
return errorRate, ttftValue, true
}
func (s *openAIAccountRuntimeStats) size() int {
if s == nil {
return 0
}
return int(s.accountCount.Load())
}
type defaultOpenAIAccountScheduler struct {
service *OpenAIGatewayService
metrics openAIAccountSchedulerMetrics
stats *openAIAccountRuntimeStats
}
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
if stats == nil {
stats = newOpenAIAccountRuntimeStats()
}
return &defaultOpenAIAccountScheduler{
service: service,
stats: stats,
}
}
func (s *defaultOpenAIAccountScheduler) Select(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
start := time.Now()
defer func() {
decision.LatencyMs = time.Since(start).Milliseconds()
s.metrics.recordSelect(decision)
}()
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
if previousResponseID != "" {
selection, err := s.service.SelectAccountByPreviousResponseID(
ctx,
req.GroupID,
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
selection = nil
}
}
if selection != nil && selection.Account != nil {
decision.Layer = openAIAccountScheduleLayerPreviousResponse
decision.StickyPreviousHit = true
decision.SelectedAccountID = selection.Account.ID
decision.SelectedAccountType = selection.Account.Type
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
}
return selection, decision, nil
}
}
selection, err := s.selectBySessionHash(ctx, req)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
decision.Layer = openAIAccountScheduleLayerSessionSticky
decision.StickySessionHit = true
decision.SelectedAccountID = selection.Account.ID
decision.SelectedAccountType = selection.Account.Type
return selection, decision, nil
}
selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
decision.Layer = openAIAccountScheduleLayerLoadBalance
decision.CandidateCount = candidateCount
decision.TopK = topK
decision.LoadSkew = loadSkew
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
decision.SelectedAccountID = selection.Account.ID
decision.SelectedAccountType = selection.Account.Type
}
return selection, decision, nil
}
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
sessionHash := strings.TrimSpace(req.SessionHash)
if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
return nil, nil
}
accountID := req.StickyAccountID
if accountID <= 0 {
var err error
accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
if err != nil || accountID <= 0 {
return nil, nil
}
}
if accountID <= 0 {
return nil, nil
}
if req.ExcludedIDs != nil {
if _, excluded := req.ExcludedIDs[accountID]; excluded {
return nil, nil
}
}
account, err := s.service.getSchedulableAccount(ctx, accountID)
if err != nil || account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return nil, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
if account == nil {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
if s.service.concurrencyService != nil {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
return nil, nil
}
type openAIAccountCandidateScore struct {
account *Account
loadInfo *AccountLoadInfo
score float64
errorRate float64
ttft float64
hasTTFT bool
}
type openAIAccountCandidateHeap []openAIAccountCandidateScore
func (h openAIAccountCandidateHeap) Len() int {
return len(h)
}
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
// 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。
return isOpenAIAccountCandidateBetter(h[j], h[i])
}
func (h openAIAccountCandidateHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h *openAIAccountCandidateHeap) Push(x any) {
candidate, ok := x.(openAIAccountCandidateScore)
if !ok {
panic("openAIAccountCandidateHeap: invalid element type")
}
*h = append(*h, candidate)
}
func (h *openAIAccountCandidateHeap) Pop() any {
old := *h
n := len(old)
last := old[n-1]
*h = old[:n-1]
return last
}
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
if left.score != right.score {
return left.score > right.score
}
if left.account.Priority != right.account.Priority {
return left.account.Priority < right.account.Priority
}
if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
return left.loadInfo.LoadRate < right.loadInfo.LoadRate
}
if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
}
return left.account.ID < right.account.ID
}
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
if len(candidates) == 0 {
return nil
}
if topK <= 0 {
topK = 1
}
if topK >= len(candidates) {
ranked := append([]openAIAccountCandidateScore(nil), candidates...)
sort.Slice(ranked, func(i, j int) bool {
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
})
return ranked
}
best := make(openAIAccountCandidateHeap, 0, topK)
for _, candidate := range candidates {
if len(best) < topK {
heap.Push(&best, candidate)
continue
}
if isOpenAIAccountCandidateBetter(candidate, best[0]) {
best[0] = candidate
heap.Fix(&best, 0)
}
}
ranked := make([]openAIAccountCandidateScore, len(best))
copy(ranked, best)
sort.Slice(ranked, func(i, j int) bool {
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
})
return ranked
}
type openAISelectionRNG struct {
state uint64
}
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
if seed == 0 {
seed = 0x9e3779b97f4a7c15
}
return openAISelectionRNG{state: seed}
}
func (r *openAISelectionRNG) nextUint64() uint64 {
// xorshift64*
x := r.state
x ^= x >> 12
x ^= x << 25
x ^= x >> 27
r.state = x
return x * 2685821657736338717
}
func (r *openAISelectionRNG) nextFloat64() float64 {
// [0,1)
return float64(r.nextUint64()>>11) / (1 << 53)
}
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
hasher := fnv.New64a()
writeValue := func(value string) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return
}
_, _ = hasher.Write([]byte(trimmed))
_, _ = hasher.Write([]byte{0})
}
writeValue(req.SessionHash)
writeValue(req.PreviousResponseID)
writeValue(req.RequestedModel)
if req.GroupID != nil {
_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
}
seed := hasher.Sum64()
// 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。
if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
seed ^= uint64(time.Now().UnixNano())
}
if seed == 0 {
seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
}
return seed
}
func buildOpenAIWeightedSelectionOrder(
candidates []openAIAccountCandidateScore,
req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
if len(candidates) <= 1 {
return append([]openAIAccountCandidateScore(nil), candidates...)
}
pool := append([]openAIAccountCandidateScore(nil), candidates...)
weights := make([]float64, len(pool))
minScore := pool[0].score
for i := 1; i < len(pool); i++ {
if pool[i].score < minScore {
minScore = pool[i].score
}
}
for i := range pool {
// 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。
weight := (pool[i].score - minScore) + 1.0
if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
weight = 1.0
}
weights[i] = weight
}
order := make([]openAIAccountCandidateScore, 0, len(pool))
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
for len(pool) > 0 {
total := 0.0
for _, w := range weights {
total += w
}
selectedIdx := 0
if total > 0 {
r := rng.nextFloat64() * total
acc := 0.0
for i, w := range weights {
acc += w
if r <= acc {
selectedIdx = i
break
}
}
} else {
selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
}
order = append(order, pool[selectedIdx])
pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
}
return order
}
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
if err != nil {
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// require_privacy_set: 获取分组信息
var schedGroup *Group
if req.GroupID != nil && s.service.schedulerSnapshot != nil {
schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID)
}
filtered := make([]*Account, 0, len(accounts))
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
for i := range accounts {
account := &accounts[i]
if req.ExcludedIDs != nil {
if _, excluded := req.ExcludedIDs[account.ID]; excluded {
continue
}
}
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
_ = s.service.accountRepo.SetError(ctx, account.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
continue
}
filtered = append(filtered, account)
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.EffectiveLoadFactor(),
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
loadMap := map[int64]*AccountLoadInfo{}
if s.service.concurrencyService != nil {
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
loadMap = batchLoad
}
}
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
}
if account.Priority > maxPriority {
maxPriority = account.Priority
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
}
}
}
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
topK := s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
AccountID: fresh.ID,
MaxConcurrency: fresh.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || s.service == nil {
return false
}
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
if s == nil || s.stats == nil {
return
}
s.stats.report(accountID, success, firstTokenMs)
}
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
if s == nil {
return
}
s.metrics.recordSwitch()
}
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
if s == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
selectTotal := s.metrics.selectTotal.Load()
prevHit := s.metrics.stickyPreviousHitTotal.Load()
sessionHit := s.metrics.stickySessionHitTotal.Load()
switchTotal := s.metrics.accountSwitchTotal.Load()
latencyTotal := s.metrics.latencyMsTotal.Load()
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
SelectTotal: selectTotal,
StickyPreviousHitTotal: prevHit,
StickySessionHitTotal: sessionHit,
LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
AccountSwitchTotal: switchTotal,
SchedulerLatencyMsTotal: latencyTotal,
RuntimeStatsAccountCount: s.stats.size(),
}
if selectTotal > 0 {
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
}
return snapshot
}
func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
return nil
}
return s.rateLimitService.settingService.settingRepo
}
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled
}
}
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled, nil
}
}
enabled := false
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
defer cancel()
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
if err == nil {
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
}
}
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: enabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
return enabled, nil
})
enabled, _ := result.(bool)
return enabled
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
return nil
}
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
}
if s.openaiScheduler == nil {
s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
}
})
return s.openaiScheduler
}
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
openAIAdvancedSchedulerSettingCache = atomic.Value{}
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
}
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
previousResponseID string,
sessionHash string,
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
decision.Layer = openAIAccountScheduleLayerLoadBalance
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
return selection, decision, err
}
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if effectiveExcludedIDs == nil {
effectiveExcludedIDs = make(map[int64]struct{})
}
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
return nil, decision, ErrNoAvailableAccounts
}
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
}
}
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
stickyAccountID = accountID
}
}
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
GroupID: groupID,
SessionHash: sessionHash,
StickyAccountID: stickyAccountID,
PreviousResponseID: previousResponseID,
RequestedModel: requestedModel,
RequiredTransport: requiredTransport,
ExcludedIDs: excludedIDs,
})
}
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
if len(excludedIDs) == 0 {
return nil
}
cloned := make(map[int64]struct{}, len(excludedIDs))
for id := range excludedIDs {
cloned[id] = struct{}{}
}
return cloned
}
func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
if s == nil || account == nil {
return false
}
return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
scheduler.ReportResult(accountID, success, firstTokenMs)
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
scheduler.ReportSwitch()
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
return scheduler.SnapshotMetrics()
}
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
}
return openaiStickySessionTTL
}
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
return s.cfg.Gateway.OpenAIWS.LBTopK
}
return 7
}
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
if s != nil && s.cfg != nil {
return GatewayOpenAIWSSchedulerScoreWeightsView{
Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
}
}
return GatewayOpenAIWSSchedulerScoreWeightsView{
Priority: 1.0,
Load: 1.0,
Queue: 0.7,
ErrorRate: 0.8,
TTFT: 0.5,
}
}
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
Priority float64
Load float64
Queue float64
ErrorRate float64
TTFT float64
}
func clamp01(value float64) float64 {
switch {
case value < 0:
return 0
case value > 1:
return 1
default:
return value
}
}
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
if count <= 1 {
return 0
}
mean := sum / float64(count)
variance := sumSquares/float64(count) - mean*mean
if variance < 0 {
variance = 0
}
return math.Sqrt(variance)
}