Files
xinghuoapi/backend/internal/service/openai_account_scheduler.go
erio 12b33f4ea4 fix: address load_factor code review findings
- Fix bulk edit: send 0 instead of null/NaN to clear load_factor
- Fix edit modal: explicit NaN check instead of implicit falsy
- Fix create modal: use ?? instead of || for load_factor
- Add load_factor upper limit validation (max 10000)
- Add //go:build unit tag and self-contained intPtrHelper in test
- Add design intent comments on WaitPlan.MaxConcurrency
2026-03-06 05:07:52 +08:00

912 lines
25 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"container/heap"
"context"
"errors"
"hash/fnv"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Scheduling layer labels stored in OpenAIAccountScheduleDecision.Layer to
// record which stage of the scheduler produced the selection.
const (
// Selected through the previous_response_id lookup.
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
// Selected through the sticky session bound to the session hash.
openAIAccountScheduleLayerSessionSticky = "session_hash"
// Selected by the scored load-balancing fallback.
openAIAccountScheduleLayerLoadBalance = "load_balance"
)
// OpenAIAccountScheduleRequest carries the inputs for one account-selection
// pass of the OpenAI account scheduler.
type OpenAIAccountScheduleRequest struct {
// GroupID scopes account listing and sticky-session lookups; may be nil.
GroupID *int64
// SessionHash is the sticky-session key; blank disables the session layer.
SessionHash string
// StickyAccountID is an already-resolved sticky account; when <= 0 the
// scheduler looks the binding up itself.
StickyAccountID int64
// PreviousResponseID, when non-blank, is tried first to reuse the account
// associated with that response.
PreviousResponseID string
// RequestedModel filters out accounts that do not support the model;
// an empty value skips the model filter.
RequestedModel string
// RequiredTransport restricts candidates to accounts whose resolved
// transport matches, unless it is "any" or HTTP SSE.
RequiredTransport OpenAIUpstreamTransport
// ExcludedIDs lists account IDs that must not be selected.
ExcludedIDs map[int64]struct{}
}
// OpenAIAccountScheduleDecision describes how a selection was made; it is
// returned alongside the selection and fed into the scheduler metrics.
type OpenAIAccountScheduleDecision struct {
// Layer is one of the openAIAccountScheduleLayer* labels.
Layer string
// StickyPreviousHit is true when the previous-response layer selected the account.
StickyPreviousHit bool
// StickySessionHit is true when the session-hash layer selected the account.
StickySessionHit bool
// CandidateCount is the number of accounts scored by the load-balance layer.
CandidateCount int
// TopK is the effective top-K used by the load-balance layer.
TopK int
// LatencyMs is the wall-clock duration of the selection in milliseconds.
LatencyMs int64
// LoadSkew is the standard deviation of candidate load rates.
LoadSkew float64
// SelectedAccountID is the ID of the chosen account, if any.
SelectedAccountID int64
// SelectedAccountType is the Type of the chosen account, if any.
SelectedAccountType string
}
// OpenAIAccountSchedulerMetricsSnapshot is a point-in-time copy of the
// scheduler counters plus ratios derived from them.
type OpenAIAccountSchedulerMetricsSnapshot struct {
// SelectTotal counts recorded selections.
SelectTotal int64
// StickyPreviousHitTotal counts selections served by the previous-response layer.
StickyPreviousHitTotal int64
// StickySessionHitTotal counts selections served by the session-hash layer.
StickySessionHitTotal int64
// LoadBalanceSelectTotal counts selections whose layer was load balance.
LoadBalanceSelectTotal int64
// AccountSwitchTotal counts reported account switches.
AccountSwitchTotal int64
// SchedulerLatencyMsTotal is the summed selection latency in milliseconds.
SchedulerLatencyMsTotal int64
// SchedulerLatencyMsAvg is SchedulerLatencyMsTotal / SelectTotal.
SchedulerLatencyMsAvg float64
// StickyHitRatio is (previous hits + session hits) / SelectTotal.
StickyHitRatio float64
// AccountSwitchRate is AccountSwitchTotal / SelectTotal.
AccountSwitchRate float64
// LoadSkewAvg is the mean load skew per selection.
LoadSkewAvg float64
// RuntimeStatsAccountCount is the number of accounts with runtime stats.
RuntimeStatsAccountCount int
}
// OpenAIAccountScheduler selects upstream OpenAI accounts for requests and
// collects feedback used for later selections and metrics.
type OpenAIAccountScheduler interface {
// Select chooses an account for req and reports how the choice was made.
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
// ReportResult feeds back a request outcome for accountID; firstTokenMs
// may be nil when no first-token latency is available.
ReportResult(accountID int64, success bool, firstTokenMs *int)
// ReportSwitch records that a request switched accounts.
ReportSwitch()
// SnapshotMetrics returns the current scheduler metrics.
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
// openAIAccountSchedulerMetrics holds the scheduler's cumulative counters.
// Every field is an atomic counter updated via Add and read via Load.
type openAIAccountSchedulerMetrics struct {
selectTotal atomic.Int64
stickyPreviousHitTotal atomic.Int64
stickySessionHitTotal atomic.Int64
loadBalanceSelectTotal atomic.Int64
accountSwitchTotal atomic.Int64
latencyMsTotal atomic.Int64
// loadSkewMilliTotal accumulates LoadSkew * 1000 (rounded) so the float
// value can be summed in an integer counter.
loadSkewMilliTotal atomic.Int64
}
// recordSelect folds one scheduling decision into the cumulative counters.
// A nil receiver is a no-op. LoadSkew is accumulated in thousandths,
// rounded to the nearest integer.
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
	if m == nil {
		return
	}

	skewMilli := int64(math.Round(decision.LoadSkew * 1000))

	m.selectTotal.Add(1)
	m.latencyMsTotal.Add(decision.LatencyMs)
	m.loadSkewMilliTotal.Add(skewMilli)

	if decision.StickyPreviousHit {
		m.stickyPreviousHitTotal.Add(1)
	}
	if decision.StickySessionHit {
		m.stickySessionHitTotal.Add(1)
	}
	if isLoadBalance := decision.Layer == openAIAccountScheduleLayerLoadBalance; isLoadBalance {
		m.loadBalanceSelectTotal.Add(1)
	}
}
// recordSwitch increments the account-switch counter; a nil receiver is a no-op.
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
	if m != nil {
		m.accountSwitchTotal.Add(1)
	}
}
// openAIAccountRuntimeStats tracks per-account runtime statistics.
type openAIAccountRuntimeStats struct {
// accounts maps account ID (int64) to *openAIAccountRuntimeStat.
accounts sync.Map
// accountCount is incremented each time a new entry is stored in accounts.
accountCount atomic.Int64
}
// openAIAccountRuntimeStat stores one account's moving averages as float64
// bit patterns so they can be updated with compare-and-swap.
type openAIAccountRuntimeStat struct {
// errorRateEWMABits is the EWMA of the error indicator (1 on failure, 0 on success).
errorRateEWMABits atomic.Uint64
// ttftEWMABits is the EWMA of first-token latency; NaN means no sample yet.
ttftEWMABits atomic.Uint64
}
// newOpenAIAccountRuntimeStats returns an empty runtime-stats registry.
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
	return new(openAIAccountRuntimeStats)
}
// loadOrCreate returns the stat entry for accountID, creating it on first use.
// A new entry starts with its TTFT average set to NaN ("no sample yet"), and
// the account counter is bumped only by the caller whose entry was stored.
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
	// Fast path: the entry already exists.
	if cached, found := s.accounts.Load(accountID); found {
		if existing, _ := cached.(*openAIAccountRuntimeStat); existing != nil {
			return existing
		}
	}

	fresh := &openAIAccountRuntimeStat{}
	fresh.ttftEWMABits.Store(math.Float64bits(math.NaN()))

	stored, alreadyPresent := s.accounts.LoadOrStore(accountID, fresh)
	if !alreadyPresent {
		s.accountCount.Add(1)
		return fresh
	}
	// Another caller stored an entry first; prefer theirs when it is usable.
	if winner, _ := stored.(*openAIAccountRuntimeStat); winner != nil {
		return winner
	}
	return fresh
}
// updateEWMAAtomic folds sample into the exponentially weighted moving
// average stored in target as float64 bits, retrying the compare-and-swap
// until it succeeds.
//
// alpha is the weight given to the new sample. A stored NaN is treated as
// "no sample yet" and is replaced by sample directly, so a NaN-initialized
// average is seeded by its first observation instead of staying NaN forever.
// A NaN sample or a nil target is ignored, which keeps one bad observation
// from poisoning the average.
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
	if target == nil || math.IsNaN(sample) {
		return
	}
	for {
		oldBits := target.Load()
		oldValue := math.Float64frombits(oldBits)

		newValue := sample
		if !math.IsNaN(oldValue) {
			newValue = alpha*sample + (1-alpha)*oldValue
		}
		if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
			return
		}
	}
}
// report records the outcome of one request for accountID.
//
// The error-rate average always receives a sample (1 for failure, 0 for
// success). The TTFT average is updated only when firstTokenMs is non-nil
// and positive; its first sample replaces the initial NaN marker directly.
// A nil receiver or a non-positive accountID is ignored.
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
	if s == nil || accountID <= 0 {
		return
	}
	const alpha = 0.2

	stat := s.loadOrCreate(accountID)

	var failure float64
	if !success {
		failure = 1.0
	}
	updateEWMAAtomic(&stat.errorRateEWMABits, failure, alpha)

	if firstTokenMs == nil || *firstTokenMs <= 0 {
		return
	}

	sample := float64(*firstTokenMs)
	for {
		prevBits := stat.ttftEWMABits.Load()
		prev := math.Float64frombits(prevBits)

		// NaN means no TTFT has been observed yet: seed with the raw sample.
		next := sample
		if !math.IsNaN(prev) {
			next = alpha*sample + (1-alpha)*prev
		}
		if stat.ttftEWMABits.CompareAndSwap(prevBits, math.Float64bits(next)) {
			return
		}
	}
}
// snapshot returns the current averages for accountID.
//
// errorRate is clamped to [0, 1]. hasTTFT is false (and ttft is 0) while the
// stored TTFT is still NaN. A nil receiver, a non-positive ID, or an unknown
// account yields (0, 0, false).
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
	if s == nil || accountID <= 0 {
		return 0, 0, false
	}

	// A missing key yields a nil interface, which the assertion turns into a nil entry.
	raw, _ := s.accounts.Load(accountID)
	entry, _ := raw.(*openAIAccountRuntimeStat)
	if entry == nil {
		return 0, 0, false
	}

	errorRate = clamp01(math.Float64frombits(entry.errorRateEWMABits.Load()))
	latency := math.Float64frombits(entry.ttftEWMABits.Load())
	if math.IsNaN(latency) {
		return errorRate, 0, false
	}
	return errorRate, latency, true
}
// size reports how many accounts have a stats entry; a nil receiver reports 0.
func (s *openAIAccountRuntimeStats) size() int {
	if s != nil {
		return int(s.accountCount.Load())
	}
	return 0
}
// defaultOpenAIAccountScheduler is the OpenAIAccountScheduler implementation
// returned by newDefaultOpenAIAccountScheduler.
type defaultOpenAIAccountScheduler struct {
// service provides account lookup, sticky-session storage, slot
// acquisition and configuration.
service *OpenAIGatewayService
// metrics accumulates selection counters.
metrics openAIAccountSchedulerMetrics
// stats holds per-account error-rate and TTFT averages used for scoring.
stats *openAIAccountRuntimeStats
}
// newDefaultOpenAIAccountScheduler builds the default scheduler on top of
// service. When stats is nil a fresh runtime-stats registry is created.
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
	scheduler := &defaultOpenAIAccountScheduler{service: service, stats: stats}
	if scheduler.stats == nil {
		scheduler.stats = newOpenAIAccountRuntimeStats()
	}
	return scheduler
}
// Select picks an account for req by trying three layers in order:
//
//  1. previous_response_id: reuse the account tied to the previous response;
//  2. session_hash: reuse the account bound to the sticky session;
//  3. load_balance: score all eligible accounts and pick among the top-K.
//
// The returned decision records which layer produced the result. decision is
// a named result so that the deferred latency measurement is visible to the
// caller as well as to the metrics; with an unnamed result the caller would
// always observe LatencyMs == 0.
func (s *defaultOpenAIAccountScheduler) Select(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (_ *AccountSelectionResult, decision OpenAIAccountScheduleDecision, _ error) {
	start := time.Now()
	defer func() {
		decision.LatencyMs = time.Since(start).Milliseconds()
		s.metrics.recordSelect(decision)
	}()

	// Layer 1: previous_response_id stickiness.
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	if previousResponseID != "" {
		selection, err := s.service.SelectAccountByPreviousResponseID(
			ctx,
			req.GroupID,
			previousResponseID,
			req.RequestedModel,
			req.ExcludedIDs,
		)
		if err != nil {
			return nil, decision, err
		}
		if selection != nil && selection.Account != nil &&
			!s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
			// The account cannot serve the required transport. Hand back
			// whatever the lookup acquired before discarding it so the slot
			// is not leaked. NOTE(review): assumes ReleaseFunc is a func()
			// that frees the acquired slot — confirm.
			if selection.ReleaseFunc != nil {
				selection.ReleaseFunc()
			}
			selection = nil
		}
		if selection != nil && selection.Account != nil {
			decision.Layer = openAIAccountScheduleLayerPreviousResponse
			decision.StickyPreviousHit = true
			decision.SelectedAccountID = selection.Account.ID
			decision.SelectedAccountType = selection.Account.Type
			if req.SessionHash != "" {
				// Best effort: a failed bind only costs stickiness on the next request.
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
			}
			return selection, decision, nil
		}
	}

	// Layer 2: session-hash stickiness.
	selection, err := s.selectBySessionHash(ctx, req)
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
		return selection, decision, nil
	}

	// Layer 3: scored load balancing. The load-balance details are recorded
	// even when the layer fails, so the metrics still see the attempt.
	selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
	}
	return selection, decision, nil
}
// selectBySessionHash tries to reuse the account bound to the request's
// sticky session.
//
// It returns (nil, nil) whenever the sticky layer does not apply: no session
// hash, no cache, no binding, an excluded account, or an account that is no
// longer usable for this request. The binding is deleted when the account
// cannot be loaded, should be cleared for the requested model, is not an
// OpenAI account, or is transport-incompatible; an unsupported model alone
// leaves the binding in place.
//
// When the account is usable but no slot can be acquired, a WaitPlan for
// that account is returned, provided a concurrency service is configured.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}

	// Prefer the pre-resolved sticky account; otherwise look the binding up.
	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil, nil
		}
	}
	// Reading from a nil map is safe, so no separate nil check is needed.
	if _, excluded := req.ExcludedIDs[accountID]; excluded {
		return nil, nil
	}

	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}

	// Guard against a nil result, matching the check in selectByLoadBalance.
	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	if acquireErr == nil && result != nil && result.Acquired {
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}

	if s.service.concurrencyService == nil {
		return nil, nil
	}
	cfg := s.service.schedulingConfig()
	// WaitPlan.MaxConcurrency uses Concurrency (not EffectiveLoadFactor)
	// because the wait plan governs waiting on the actual Redis concurrency slots.
	return &AccountSelectionResult{
		Account: account,
		WaitPlan: &AccountWaitPlan{
			AccountID:      accountID,
			MaxConcurrency: account.Concurrency,
			Timeout:        cfg.StickySessionWaitTimeout,
			MaxWaiting:     cfg.StickySessionMaxWaiting,
		},
	}, nil
}
// openAIAccountCandidateScore pairs a candidate account with the inputs and
// result of load-balance scoring.
type openAIAccountCandidateScore struct {
// account is the candidate being scored.
account *Account
// loadInfo is the candidate's load data (load rate and waiting count).
loadInfo *AccountLoadInfo
// score is the weighted sum of the scoring factors; higher is better.
score float64
// errorRate is the account's error-rate average from the runtime stats.
errorRate float64
// ttft is the account's first-token latency average; valid only if hasTTFT.
ttft float64
// hasTTFT reports whether a TTFT sample exists for the account.
hasTTFT bool
}
// openAIAccountCandidateHeap implements heap.Interface over candidate scores.
// Ordering is inverted so that the root is the worst candidate held.
type openAIAccountCandidateHeap []openAIAccountCandidateScore

// Len returns the number of candidates in the heap.
func (h openAIAccountCandidateHeap) Len() int { return len(h) }

// Less reports whether element i ranks below element j.
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
	// The min-heap root holds the "worst" candidate, so the top-K set can be
	// maintained in O(log k).
	lower, higher := h[i], h[j]
	return isOpenAIAccountCandidateBetter(higher, lower)
}

// Swap exchanges elements i and j.
func (h openAIAccountCandidateHeap) Swap(i, j int) { h[j], h[i] = h[i], h[j] }

// Push appends x; it panics when x is not an openAIAccountCandidateScore.
func (h *openAIAccountCandidateHeap) Push(x any) {
	if candidate, ok := x.(openAIAccountCandidateScore); ok {
		*h = append(*h, candidate)
		return
	}
	panic("openAIAccountCandidateHeap: invalid element type")
}

// Pop removes and returns the last element.
func (h *openAIAccountCandidateHeap) Pop() any {
	items := *h
	lastIdx := len(items) - 1
	popped := items[lastIdx]
	*h = items[:lastIdx]
	return popped
}
// isOpenAIAccountCandidateBetter reports whether left ranks ahead of right.
// Tie-breakers apply in order: higher score, lower Priority value, lower
// load rate, fewer waiting requests, and finally lower account ID.
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
	switch {
	case left.score != right.score:
		return left.score > right.score
	case left.account.Priority != right.account.Priority:
		return left.account.Priority < right.account.Priority
	case left.loadInfo.LoadRate != right.loadInfo.LoadRate:
		return left.loadInfo.LoadRate < right.loadInfo.LoadRate
	case left.loadInfo.WaitingCount != right.loadInfo.WaitingCount:
		return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
	}
	return left.account.ID < right.account.ID
}
// selectTopKOpenAICandidates returns the best topK candidates, best first.
//
// Empty input yields nil and a non-positive topK is treated as 1. When topK
// covers the whole input, a sorted copy is returned; otherwise a bounded
// heap whose root is the worst kept candidate selects the top-K before the
// final sort. The input slice is never modified.
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	total := len(candidates)
	if total == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}

	var ranked []openAIAccountCandidateScore
	if topK >= total {
		ranked = append(ranked, candidates...)
	} else {
		keep := make(openAIAccountCandidateHeap, 0, topK)
		for idx := range candidates {
			switch {
			case len(keep) < topK:
				heap.Push(&keep, candidates[idx])
			case isOpenAIAccountCandidateBetter(candidates[idx], keep[0]):
				// Evict the current worst candidate and restore heap order.
				keep[0] = candidates[idx]
				heap.Fix(&keep, 0)
			}
		}
		ranked = make([]openAIAccountCandidateScore, len(keep))
		copy(ranked, keep)
	}

	sort.Slice(ranked, func(a, b int) bool {
		return isOpenAIAccountCandidateBetter(ranked[a], ranked[b])
	})
	return ranked
}
// openAISelectionRNG is a small deterministic xorshift64* generator used to
// order candidates.
type openAISelectionRNG struct {
	state uint64
}

// newOpenAISelectionRNG returns a generator seeded with seed. A zero seed is
// replaced by a fixed non-zero constant, since xorshift cannot leave state 0.
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	rng := openAISelectionRNG{state: seed}
	if rng.state == 0 {
		rng.state = 0x9e3779b97f4a7c15
	}
	return rng
}

// nextUint64 advances the generator and returns the next 64-bit value.
func (r *openAISelectionRNG) nextUint64() uint64 {
	// xorshift64*
	const multiplier uint64 = 2685821657736338717
	s := r.state
	s ^= s >> 12
	s ^= s << 25
	s ^= s >> 27
	r.state = s
	return s * multiplier
}

// nextFloat64 returns the next value in [0,1), built from the top 53 bits.
func (r *openAISelectionRNG) nextFloat64() float64 {
	const mantissaRange = 1 << 53
	top53 := r.nextUint64() >> 11
	return float64(top53) / mantissaRange
}
// deriveOpenAISelectionSeed builds the RNG seed for ordering candidates.
//
// The seed is an FNV-1a hash of the trimmed, non-empty session hash,
// previous response ID and requested model (each followed by a zero byte),
// plus the decimal group ID when present. Requests with neither a session
// hash nor a previous response ID also mix in the current time, and a zero
// seed is replaced by a time-derived value.
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
	sessionHash := strings.TrimSpace(req.SessionHash)
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	requestedModel := strings.TrimSpace(req.RequestedModel)

	digest := fnv.New64a()
	for _, part := range [...]string{sessionHash, previousResponseID, requestedModel} {
		if part == "" {
			continue
		}
		_, _ = digest.Write([]byte(part))
		_, _ = digest.Write([]byte{0})
	}
	if req.GroupID != nil {
		_, _ = digest.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
	}

	seed := digest.Sum64()
	// For pure load-balancing requests without a session anchor, mix in time
	// entropy so they do not always land on the same account.
	if sessionHash == "" && previousResponseID == "" {
		seed ^= uint64(time.Now().UnixNano())
	}
	if seed == 0 {
		seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
	}
	return seed
}
// buildOpenAIWeightedSelectionOrder returns candidates in a weighted-random
// order, drawn without replacement.
//
// Each candidate's weight is its score minus the lowest score, plus 1, so
// every weight is positive; an invalid weight falls back to 1. The RNG is
// seeded from req via deriveOpenAISelectionSeed. The input slice is copied
// and never modified.
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	remaining := append([]openAIAccountCandidateScore(nil), candidates...)
	if len(remaining) <= 1 {
		return remaining
	}

	floor := remaining[0].score
	for _, c := range remaining[1:] {
		if c.score < floor {
			floor = c.score
		}
	}

	weights := make([]float64, len(remaining))
	for idx := range remaining {
		// Shift the top-K scores into a positive range so that a single
		// top-scoring account cannot monopolize selection over time.
		w := (remaining[idx].score - floor) + 1.0
		if math.IsNaN(w) || math.IsInf(w, 0) || w <= 0 {
			w = 1.0
		}
		weights[idx] = w
	}

	order := make([]openAIAccountCandidateScore, 0, len(remaining))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	for len(remaining) > 0 {
		var total float64
		for idx := 0; idx < len(weights); idx++ {
			total += weights[idx]
		}

		pick := 0
		if total > 0 {
			threshold := rng.nextFloat64() * total
			cumulative := 0.0
			for idx := 0; idx < len(weights); idx++ {
				cumulative += weights[idx]
				if threshold <= cumulative {
					pick = idx
					break
				}
			}
		} else {
			pick = int(rng.nextUint64() % uint64(len(remaining)))
		}

		order = append(order, remaining[pick])

		// Remove the drawn entry from both parallel slices.
		last := len(remaining) - 1
		copy(remaining[pick:], remaining[pick+1:])
		remaining = remaining[:last]
		copy(weights[pick:], weights[pick+1:])
		weights = weights[:last]
	}
	return order
}
// selectByLoadBalance is the final scheduling layer: it filters the group's
// accounts, scores each one, keeps the top-K, and tries to acquire a slot in
// weighted-random order.
//
// It returns the selection, the number of scored candidates, the effective
// top-K, the load skew (standard deviation of the candidates' load rates),
// and an error. When no candidate yields a slot, the result carries a
// WaitPlan for the first account in the selection order instead.
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
if err != nil {
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// Step 1: drop accounts that are excluded, unschedulable, not OpenAI,
// lacking the requested model, or transport-incompatible.
filtered := make([]*Account, 0, len(accounts))
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
for i := range accounts {
// Take the element's address so filtered points into the accounts slice.
account := &accounts[i]
if req.ExcludedIDs != nil {
if _, excluded := req.ExcludedIDs[account.ID]; excluded {
continue
}
}
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
continue
}
filtered = append(filtered, account)
// The load query uses EffectiveLoadFactor as the capacity, unlike slot
// acquisition below, which uses Concurrency.
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.EffectiveLoadFactor(),
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// Step 2: fetch load info in one batch. A failed or unavailable lookup
// leaves the map empty, so every account falls back to zero load.
loadMap := map[int64]*AccountLoadInfo{}
if s.service.concurrencyService != nil {
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
loadMap = batchLoad
}
}
// Step 3: gather the ranges needed to normalize the scoring factors.
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
// maxWaiting starts at 1 so the queue factor never divides by zero.
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
}
if account.Priority > maxPriority {
maxPriority = account.Priority
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
}
}
}
// Accumulate first and second moments of the load rate for the skew.
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
// Step 4: score every candidate as a weighted sum of factors in [0, 1],
// where 1 is best for each factor.
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
// A lower Priority value yields a higher factor; equal priorities give 1.
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
// LoadRate is divided by 100 before clamping, i.e. treated as a percentage.
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
// Accounts without TTFT data, or a set with no TTFT spread, get a neutral 0.5.
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
// Step 5: clamp top-K to [1, len(candidates)] and rank.
topK := s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
// Step 6: try candidates in order. An acquisition error aborts the whole
// selection rather than moving on to the next candidate.
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
// Best effort: the bind error is intentionally ignored.
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
}
return &AccountSelectionResult{
Account: candidate.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}
}
// Step 7: no slot was free, so return a wait plan for the first candidate.
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency uses Concurrency (not EffectiveLoadFactor) because
// the wait plan governs waiting on the actual Redis concurrency slots.
candidate := selectionOrder[0]
return &AccountSelectionResult{
Account: candidate.account,
WaitPlan: &AccountWaitPlan{
AccountID: candidate.account.ID,
MaxConcurrency: candidate.account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}
// isAccountTransportCompatible reports whether account can serve a request
// that needs requiredTransport. "Any" and HTTP SSE always pass; otherwise
// the account's resolved transport must equal requiredTransport, and a nil
// scheduler, service or account counts as incompatible.
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
	// HTTP inbound traffic can fall back to the HTTP route, so no hard
	// transport filtering is needed at account-selection time.
	transportAgnostic := requiredTransport == OpenAIUpstreamTransportAny ||
		requiredTransport == OpenAIUpstreamTransportHTTPSSE
	if transportAgnostic {
		return true
	}
	if s == nil || s.service == nil || account == nil {
		return false
	}
	resolved := s.service.getOpenAIWSProtocolResolver().Resolve(account)
	return resolved.Transport == requiredTransport
}
// ReportResult forwards a request outcome to the runtime stats. It does
// nothing when the scheduler or its stats registry is nil.
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
	if s != nil && s.stats != nil {
		s.stats.report(accountID, success, firstTokenMs)
	}
}
// ReportSwitch counts one account switch; a nil scheduler is a no-op.
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
	if s != nil {
		s.metrics.recordSwitch()
	}
}
// SnapshotMetrics copies the counters and derives the averages and ratios.
// The derived fields stay zero until at least one selection is recorded; a
// nil scheduler yields the zero snapshot.
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if s == nil {
		return OpenAIAccountSchedulerMetricsSnapshot{}
	}

	counters := &s.metrics
	selects := counters.selectTotal.Load()
	previousHits := counters.stickyPreviousHitTotal.Load()
	sessionHits := counters.stickySessionHitTotal.Load()
	switches := counters.accountSwitchTotal.Load()
	latencyMs := counters.latencyMsTotal.Load()
	skewMilli := counters.loadSkewMilliTotal.Load()

	var out OpenAIAccountSchedulerMetricsSnapshot
	out.SelectTotal = selects
	out.StickyPreviousHitTotal = previousHits
	out.StickySessionHitTotal = sessionHits
	out.LoadBalanceSelectTotal = counters.loadBalanceSelectTotal.Load()
	out.AccountSwitchTotal = switches
	out.SchedulerLatencyMsTotal = latencyMs
	out.RuntimeStatsAccountCount = s.stats.size()

	if selects <= 0 {
		return out
	}
	out.SchedulerLatencyMsAvg = float64(latencyMs) / float64(selects)
	out.StickyHitRatio = float64(previousHits+sessionHits) / float64(selects)
	out.AccountSwitchRate = float64(switches) / float64(selects)
	// The skew counter is stored in thousandths; scale back before averaging.
	out.LoadSkewAvg = float64(skewMilli) / 1000 / float64(selects)
	return out
}
// getOpenAIAccountScheduler returns the service's scheduler, initializing it
// (and the runtime stats it uses) at most once. Values that are already set
// on the service are kept. A nil service yields nil.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	initScheduler := func() {
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler != nil {
			return
		}
		s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
	}
	s.openaiSchedulerOnce.Do(initScheduler)
	return s.openaiScheduler
}
// SelectAccountWithScheduler selects an account through the scheduler.
//
// When no scheduler is available it falls back to
// SelectAccountWithLoadAwareness and labels the decision as load balance.
// Otherwise it resolves the sticky account for sessionHash up front (when a
// cache is configured) and hands the assembled request to the scheduler.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		fallback := OpenAIAccountScheduleDecision{Layer: openAIAccountScheduleLayerLoadBalance}
		return selection, fallback, err
	}

	scheduleReq := OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	}
	if sessionHash != "" && s.cache != nil {
		// A failed lookup simply leaves StickyAccountID at zero.
		boundID, lookupErr := s.getStickySessionAccountID(ctx, groupID, sessionHash)
		if lookupErr == nil && boundID > 0 {
			scheduleReq.StickyAccountID = boundID
		}
	}
	return scheduler.Select(ctx, scheduleReq)
}
// ReportOpenAIAccountScheduleResult passes a request outcome to the
// scheduler, if one is available.
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportResult(accountID, success, firstTokenMs)
	}
}
// RecordOpenAIAccountSwitch reports an account switch to the scheduler, if
// one is available.
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportSwitch()
	}
}
// SnapshotOpenAIAccountSchedulerMetrics returns the scheduler's metrics, or
// the zero snapshot when no scheduler is available.
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		return scheduler.SnapshotMetrics()
	}
	return OpenAIAccountSchedulerMetricsSnapshot{}
}
// openAIWSSessionStickyTTL returns the configured sticky-session TTL, or
// openaiStickySessionTTL when the service, its config, or a positive
// configured value is missing.
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
	if s == nil || s.cfg == nil {
		return openaiStickySessionTTL
	}
	if s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
		return openaiStickySessionTTL
	}
	return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
}
// openAIWSLBTopK returns the configured load-balance top-K, or 7 when the
// service, its config, or a positive configured value is missing.
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
	const defaultTopK = 7
	if s == nil || s.cfg == nil {
		return defaultTopK
	}
	if s.cfg.Gateway.OpenAIWS.LBTopK <= 0 {
		return defaultTopK
	}
	return s.cfg.Gateway.OpenAIWS.LBTopK
}
// openAIWSSchedulerWeights returns the score weights from the config, or a
// built-in default set when the service or its config is nil.
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
	if s == nil || s.cfg == nil {
		var defaults GatewayOpenAIWSSchedulerScoreWeightsView
		defaults.Priority = 1.0
		defaults.Load = 1.0
		defaults.Queue = 0.7
		defaults.ErrorRate = 0.8
		defaults.TTFT = 0.5
		return defaults
	}

	var view GatewayOpenAIWSSchedulerScoreWeightsView
	view.Priority = s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority
	view.Load = s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load
	view.Queue = s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue
	view.ErrorRate = s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate
	view.TTFT = s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
	return view
}
// GatewayOpenAIWSSchedulerScoreWeightsView holds the weights applied to each
// scoring factor when ranking load-balance candidates.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
// Priority weights the account-priority factor.
Priority float64
// Load weights the load-rate factor.
Load float64
// Queue weights the waiting-count factor.
Queue float64
// ErrorRate weights the error-rate factor.
ErrorRate float64
// TTFT weights the first-token-latency factor.
TTFT float64
}
// clamp01 limits value to the closed interval [0, 1]. Values that compare
// neither below 0 nor above 1 (including NaN) are returned unchanged.
func clamp01(value float64) float64 {
	if value < 0 {
		return 0
	}
	if value > 1 {
		return 1
	}
	return value
}
// calcLoadSkewByMoments returns the population standard deviation computed
// from the sum and sum of squares of count samples. It returns 0 for fewer
// than two samples, and treats a negative variance (floating-point rounding)
// as zero.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	mean := sum / float64(count)
	variance := sumSquares/float64(count) - mean*mean
	if variance < 0 {
		return 0
	}
	return math.Sqrt(variance)
}