Files
sub2api/backend/internal/service/openai_account_scheduler.go
2026-02-28 15:01:20 +08:00

910 lines
25 KiB
Go

package service
import (
"container/heap"
"context"
"errors"
"hash/fnv"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Scheduling layer names. Select stores one of these in
// OpenAIAccountScheduleDecision.Layer to record which stage produced the
// selection.
const (
// The account was resolved from the request's previous response ID.
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
// The account was resolved from the sticky binding of the session hash.
openAIAccountScheduleLayerSessionSticky = "session_hash"
// The account was chosen by the scored load-balancing stage.
openAIAccountScheduleLayerLoadBalance = "load_balance"
)
// OpenAIAccountScheduleRequest carries the inputs of a single account
// scheduling decision.
type OpenAIAccountScheduleRequest struct {
// GroupID is the optional account group; it is passed through to the
// service lookups and mixed into the selection seed when non-nil.
GroupID *int64
// SessionHash is the sticky-session key. It is trimmed before the sticky
// lookup; a blank value skips the sticky-session layer.
SessionHash string
// StickyAccountID is an already-resolved sticky account. When it is <= 0
// the scheduler looks the binding up itself.
StickyAccountID int64
// PreviousResponseID, when non-blank, is tried first to locate the account.
PreviousResponseID string
// RequestedModel, when non-empty, filters out accounts that do not support it.
RequestedModel string
// RequiredTransport restricts accounts by upstream transport; see
// isAccountTransportCompatible for the values that disable the filter.
RequiredTransport OpenAIUpstreamTransport
// ExcludedIDs lists account IDs that must not be selected.
ExcludedIDs map[int64]struct{}
}
// OpenAIAccountScheduleDecision describes how a scheduling decision was made.
// It is returned by Select and folded into the scheduler metrics.
type OpenAIAccountScheduleDecision struct {
// Layer is one of the openAIAccountScheduleLayer* names.
Layer string
// StickyPreviousHit is true when the previous-response layer selected the account.
StickyPreviousHit bool
// StickySessionHit is true when the sticky-session layer selected the account.
StickySessionHit bool
// CandidateCount is the number of candidates scored by the load-balancing layer.
CandidateCount int
// TopK is the size of the top-K pool used by the load-balancing layer.
TopK int
// LatencyMs is the duration of Select in milliseconds as recorded into the metrics.
LatencyMs int64
// LoadSkew is the standard deviation of the candidates' load rates
// (see calcLoadSkewByMoments).
LoadSkew float64
// SelectedAccountID is the ID of the chosen account, when one was chosen.
SelectedAccountID int64
// SelectedAccountType is the Type of the chosen account, when one was chosen.
SelectedAccountType string
}
// OpenAIAccountSchedulerMetricsSnapshot is a point-in-time copy of the
// scheduler counters together with ratios derived from them. The derived
// fields stay zero while SelectTotal is zero.
type OpenAIAccountSchedulerMetricsSnapshot struct {
// SelectTotal counts recorded Select calls.
SelectTotal int64
// StickyPreviousHitTotal counts decisions with StickyPreviousHit set.
StickyPreviousHitTotal int64
// StickySessionHitTotal counts decisions with StickySessionHit set.
StickySessionHitTotal int64
// LoadBalanceSelectTotal counts decisions made by the load-balancing layer.
LoadBalanceSelectTotal int64
// AccountSwitchTotal counts ReportSwitch calls.
AccountSwitchTotal int64
// SchedulerLatencyMsTotal is the sum of recorded decision latencies in milliseconds.
SchedulerLatencyMsTotal int64
// SchedulerLatencyMsAvg is SchedulerLatencyMsTotal / SelectTotal.
SchedulerLatencyMsAvg float64
// StickyHitRatio is (previous hits + session hits) / SelectTotal.
StickyHitRatio float64
// AccountSwitchRate is AccountSwitchTotal / SelectTotal.
AccountSwitchRate float64
// LoadSkewAvg is the mean LoadSkew over all recorded decisions.
LoadSkewAvg float64
// RuntimeStatsAccountCount is the number of accounts tracked by the runtime stats.
RuntimeStatsAccountCount int
}
// OpenAIAccountScheduler selects an upstream OpenAI account for a request and
// collects feedback and metrics about those selections.
type OpenAIAccountScheduler interface {
// Select picks an account for req and reports how the decision was made.
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
// ReportResult feeds back the outcome of a request served by accountID;
// firstTokenMs is the optional time-to-first-token in milliseconds.
ReportResult(accountID int64, success bool, firstTokenMs *int)
// ReportSwitch records that an account switch happened.
ReportSwitch()
// SnapshotMetrics returns the current aggregated scheduler metrics.
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
// openAIAccountSchedulerMetrics holds the scheduler's aggregate counters.
// Every field is an atomic integer and is updated with Add.
type openAIAccountSchedulerMetrics struct {
// selectTotal counts recorded decisions.
selectTotal atomic.Int64
// stickyPreviousHitTotal counts decisions with StickyPreviousHit set.
stickyPreviousHitTotal atomic.Int64
// stickySessionHitTotal counts decisions with StickySessionHit set.
stickySessionHitTotal atomic.Int64
// loadBalanceSelectTotal counts decisions whose Layer is the load-balance layer.
loadBalanceSelectTotal atomic.Int64
// accountSwitchTotal counts recordSwitch calls.
accountSwitchTotal atomic.Int64
// latencyMsTotal is the sum of decision latencies in milliseconds.
latencyMsTotal atomic.Int64
// loadSkewMilliTotal is the sum of LoadSkew values scaled by 1000 and
// rounded, so the float can be accumulated in an integer counter.
loadSkewMilliTotal atomic.Int64
}
// recordSelect folds one scheduling decision into the aggregate counters.
// A nil receiver is a no-op.
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
	if m == nil {
		return
	}

	// LoadSkew is a float; store it in thousandths so it fits an integer counter.
	skewMilli := int64(math.Round(decision.LoadSkew * 1000))

	m.selectTotal.Add(1)
	m.latencyMsTotal.Add(decision.LatencyMs)
	m.loadSkewMilliTotal.Add(skewMilli)

	if decision.StickyPreviousHit {
		m.stickyPreviousHitTotal.Add(1)
	}
	if decision.StickySessionHit {
		m.stickySessionHitTotal.Add(1)
	}
	if decision.Layer == openAIAccountScheduleLayerLoadBalance {
		m.loadBalanceSelectTotal.Add(1)
	}
}
// recordSwitch counts one account switch. A nil receiver is a no-op.
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
	if m != nil {
		m.accountSwitchTotal.Add(1)
	}
}
// openAIAccountRuntimeStats tracks per-account runtime statistics.
type openAIAccountRuntimeStats struct {
	// accounts maps an int64 account ID to its *openAIAccountRuntimeStat.
	accounts sync.Map
	// accountCount is incremented each time a new entry is stored in accounts.
	accountCount atomic.Int64
}

// openAIAccountRuntimeStat holds the moving averages of a single account.
// Both values are float64s stored as their IEEE-754 bit patterns so they can
// be updated with compare-and-swap.
type openAIAccountRuntimeStat struct {
	errorRateEWMABits atomic.Uint64
	// ttftEWMABits starts as NaN, which marks "no TTFT sample yet".
	ttftEWMABits atomic.Uint64
}

// newOpenAIAccountRuntimeStats returns an empty statistics registry.
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
	return new(openAIAccountRuntimeStats)
}

// loadOrCreate returns the stat entry for accountID, creating it on first use.
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
	// Fast path: the entry already exists, so nothing needs to be allocated.
	if cached, found := s.accounts.Load(accountID); found {
		if stat, _ := cached.(*openAIAccountRuntimeStat); stat != nil {
			return stat
		}
	}

	fresh := new(openAIAccountRuntimeStat)
	fresh.ttftEWMABits.Store(math.Float64bits(math.NaN()))

	stored, alreadyPresent := s.accounts.LoadOrStore(accountID, fresh)
	if alreadyPresent {
		// Another goroutine stored an entry first; prefer that one.
		if winner, _ := stored.(*openAIAccountRuntimeStat); winner != nil {
			return winner
		}
		return fresh
	}

	s.accountCount.Add(1)
	return fresh
}
// updateEWMAAtomic folds sample into the exponentially weighted moving average
// stored in target as float64 bits, retrying the compare-and-swap until it wins.
//
// alpha is the weight of the new sample. A stored NaN is treated as "no value
// yet" and is replaced by the sample itself; a NaN sample is ignored. Without
// these two rules a single NaN would stick in the average forever, because
// every later blend with NaN is NaN again.
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
	if math.IsNaN(sample) {
		return
	}
	for {
		oldBits := target.Load()
		oldValue := math.Float64frombits(oldBits)
		newValue := sample
		if !math.IsNaN(oldValue) {
			newValue = alpha*sample + (1-alpha)*oldValue
		}
		if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
			return
		}
	}
}
// report feeds the outcome of one request into the account's moving averages.
//
// The error-rate average receives 0 for a success and 1 for a failure. The
// TTFT average is only touched when firstTokenMs is non-nil and positive; the
// first such sample replaces the NaN placeholder, later ones are blended in.
// Calls with a nil receiver or a non-positive accountID are ignored.
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
	const alpha = 0.2

	if s == nil || accountID <= 0 {
		return
	}
	stat := s.loadOrCreate(accountID)

	var errorSample float64
	if !success {
		errorSample = 1.0
	}
	updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)

	if firstTokenMs == nil || *firstTokenMs <= 0 {
		return
	}

	sample := float64(*firstTokenMs)
	for {
		prevBits := stat.ttftEWMABits.Load()
		prev := math.Float64frombits(prevBits)

		// NaN means no sample has been stored yet: seed with the raw sample.
		next := sample
		if !math.IsNaN(prev) {
			next = alpha*sample + (1-alpha)*prev
		}
		if stat.ttftEWMABits.CompareAndSwap(prevBits, math.Float64bits(next)) {
			return
		}
	}
}
// snapshot returns the current averages of accountID.
//
// errorRate is clamped to [0, 1]. hasTTFT is false, and ttft is 0, while no
// TTFT sample has been stored. All results are zero when the receiver is nil,
// the ID is non-positive, or the account has no entry.
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
	var stat *openAIAccountRuntimeStat
	if s != nil && accountID > 0 {
		if entry, found := s.accounts.Load(accountID); found {
			stat, _ = entry.(*openAIAccountRuntimeStat)
		}
	}
	if stat == nil {
		return 0, 0, false
	}

	errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
	if sample := math.Float64frombits(stat.ttftEWMABits.Load()); !math.IsNaN(sample) {
		return errorRate, sample, true
	}
	return errorRate, 0, false
}
// size reports how many account entries have been created; 0 for a nil receiver.
func (s *openAIAccountRuntimeStats) size() int {
	if s != nil {
		return int(s.accountCount.Load())
	}
	return 0
}
// defaultOpenAIAccountScheduler is the OpenAIAccountScheduler implementation
// backed by an OpenAIGatewayService.
type defaultOpenAIAccountScheduler struct {
	service *OpenAIGatewayService         // account lookups, slots and configuration
	metrics openAIAccountSchedulerMetrics // aggregate decision counters
	stats   *openAIAccountRuntimeStats    // per-account error-rate / TTFT averages
}

// newDefaultOpenAIAccountScheduler builds a scheduler on top of service.
// A nil stats is replaced by a fresh, empty registry.
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
	scheduler := &defaultOpenAIAccountScheduler{
		service: service,
		stats:   stats,
	}
	if scheduler.stats == nil {
		scheduler.stats = newOpenAIAccountRuntimeStats()
	}
	return scheduler
}
// Select picks an account for req by trying three layers in order:
//
//  1. the account resolved from req.PreviousResponseID (when non-blank),
//  2. the account bound to the sticky session hash,
//  3. scored, weighted load balancing over all schedulable accounts.
//
// The returned decision records which layer answered. Every call, including
// failed ones, is recorded in the scheduler metrics.
//
// The results are named on purpose: the deferred function stamps
// decision.LatencyMs after the return statement has run, and only a named
// result lets that write reach the caller. With unnamed results the caller
// always received LatencyMs == 0 while the metrics saw the real value.
func (s *defaultOpenAIAccountScheduler) Select(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (selection *AccountSelectionResult, decision OpenAIAccountScheduleDecision, err error) {
	start := time.Now()
	defer func() {
		decision.LatencyMs = time.Since(start).Milliseconds()
		s.metrics.recordSelect(decision)
	}()

	// Layer 1: previous response ID.
	if previousResponseID := strings.TrimSpace(req.PreviousResponseID); previousResponseID != "" {
		previous, prevErr := s.service.SelectAccountByPreviousResponseID(
			ctx,
			req.GroupID,
			previousResponseID,
			req.RequestedModel,
			req.ExcludedIDs,
		)
		if prevErr != nil {
			return nil, decision, prevErr
		}
		// A hit on an account that cannot serve the required transport is
		// discarded and the next layers are tried instead.
		if previous != nil && previous.Account != nil &&
			s.isAccountTransportCompatible(previous.Account, req.RequiredTransport) {
			decision.Layer = openAIAccountScheduleLayerPreviousResponse
			decision.StickyPreviousHit = true
			decision.SelectedAccountID = previous.Account.ID
			decision.SelectedAccountType = previous.Account.Type
			if req.SessionHash != "" {
				// Best effort: a failed binding must not fail the selection.
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, previous.Account.ID)
			}
			return previous, decision, nil
		}
	}

	// Layer 2: sticky session.
	sticky, stickyErr := s.selectBySessionHash(ctx, req)
	if stickyErr != nil {
		return nil, decision, stickyErr
	}
	if sticky != nil && sticky.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = sticky.Account.ID
		decision.SelectedAccountType = sticky.Account.Type
		return sticky, decision, nil
	}

	// Layer 3: load balancing. The layer statistics are recorded even when
	// the layer fails.
	balanced, candidateCount, topK, loadSkew, lbErr := s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if lbErr != nil {
		return nil, decision, lbErr
	}
	if balanced != nil && balanced.Account != nil {
		decision.SelectedAccountID = balanced.Account.ID
		decision.SelectedAccountType = balanced.Account.Type
	}
	return balanced, decision, nil
}
// selectBySessionHash tries to reuse the account bound to the request's
// sticky session.
//
// It returns (nil, nil) whenever the sticky layer does not apply: no session
// hash, no cache, no binding, an excluded account, or an account that can no
// longer serve the request. The binding is deleted when the account cannot be
// loaded, when shouldClearStickySession says so, when the account is not an
// OpenAI account, or when its transport is incompatible. An unsupported model
// leaves the binding in place.
//
// When the account is usable but no slot can be acquired right now, a wait
// plan for that account is returned, provided a concurrency service exists.
// The error result is always nil in the current implementation.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}

	// Prefer the binding pre-resolved by the caller; otherwise look it up.
	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil, nil
		}
	}
	// Reading from a nil map is safe, so no separate nil check is needed.
	if _, excluded := req.ExcludedIDs[accountID]; excluded {
		return nil, nil
	}

	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}

	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	// result is nil-checked like in selectByLoadBalance: a nil result with a
	// nil error must not panic here.
	if acquireErr == nil && result != nil && result.Acquired {
		// Best effort: a failed TTL refresh does not invalidate the slot.
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}

	// No slot right now: let the caller wait on the sticky account.
	cfg := s.service.schedulingConfig()
	if s.service.concurrencyService != nil {
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      accountID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.StickySessionWaitTimeout,
				MaxWaiting:     cfg.StickySessionMaxWaiting,
			},
		}, nil
	}
	return nil, nil
}
// openAIAccountCandidateScore is one load-balancing candidate together with
// the inputs and result of its scoring.
type openAIAccountCandidateScore struct {
// account is the candidate account.
account *Account
// loadInfo is the account's load; selectByLoadBalance substitutes a
// zero-valued entry when no load data is available.
loadInfo *AccountLoadInfo
// score is the weighted sum of the scoring factors; higher is better.
score float64
// errorRate is the account's error-rate moving average from the runtime stats.
errorRate float64
// ttft is the account's time-to-first-token moving average; only
// meaningful when hasTTFT is true.
ttft float64
// hasTTFT reports whether a TTFT sample exists for the account.
hasTTFT bool
}
// openAIAccountCandidateHeap implements heap.Interface over candidates,
// ordered so that the root is the worst candidate currently retained.
type openAIAccountCandidateHeap []openAIAccountCandidateScore

// Len returns the number of retained candidates.
func (h openAIAccountCandidateHeap) Len() int { return len(h) }

// Less orders worse candidates first: keeping the worst retained candidate at
// the root of the min-heap allows the top-K set to be maintained in O(log k)
// per replacement.
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
	return isOpenAIAccountCandidateBetter(h[j], h[i])
}

// Swap exchanges the candidates at positions i and j.
func (h openAIAccountCandidateHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be an openAIAccountCandidateScore; any other
// type panics.
func (h *openAIAccountCandidateHeap) Push(x any) {
	candidate, ok := x.(openAIAccountCandidateScore)
	if !ok {
		panic("openAIAccountCandidateHeap: invalid element type")
	}
	*h = append(*h, candidate)
}

// Pop removes and returns the last element of the underlying slice.
func (h *openAIAccountCandidateHeap) Pop() any {
	items := *h
	lastIdx := len(items) - 1
	tail := items[lastIdx]
	*h = items[:lastIdx]
	return tail
}
// isOpenAIAccountCandidateBetter reports whether left should rank ahead of
// right. Candidates are compared by higher score first, then lower Priority
// value, lower LoadRate, lower WaitingCount, and finally lower account ID.
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
	switch {
	case left.score != right.score:
		return left.score > right.score
	case left.account.Priority != right.account.Priority:
		return left.account.Priority < right.account.Priority
	case left.loadInfo.LoadRate != right.loadInfo.LoadRate:
		return left.loadInfo.LoadRate < right.loadInfo.LoadRate
	case left.loadInfo.WaitingCount != right.loadInfo.WaitingCount:
		return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
	default:
		return left.account.ID < right.account.ID
	}
}
// selectTopKOpenAICandidates returns the best topK candidates ordered from
// best to worst. The input slice is not modified.
//
// A non-positive topK is treated as 1. When topK covers the whole input, a
// copy is simply sorted; otherwise a bounded heap keeps only the topK best
// while scanning. An empty input yields nil.
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	total := len(candidates)
	if total == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}

	var kept []openAIAccountCandidateScore
	if topK >= total {
		kept = append([]openAIAccountCandidateScore(nil), candidates...)
	} else {
		worstFirst := make(openAIAccountCandidateHeap, 0, topK)
		for _, candidate := range candidates {
			switch {
			case len(worstFirst) < topK:
				heap.Push(&worstFirst, candidate)
			case isOpenAIAccountCandidateBetter(candidate, worstFirst[0]):
				// Evict the worst retained candidate in favour of this one.
				worstFirst[0] = candidate
				heap.Fix(&worstFirst, 0)
			}
		}
		kept = make([]openAIAccountCandidateScore, len(worstFirst))
		copy(kept, worstFirst)
	}

	sort.Slice(kept, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(kept[i], kept[j])
	})
	return kept
}
// openAISelectionRNG is a small seeded xorshift64* pseudo-random generator
// used to order load-balancing candidates.
type openAISelectionRNG struct {
	state uint64
}

// newOpenAISelectionRNG returns a generator seeded with seed. A zero seed is
// replaced by a fixed non-zero constant, because xorshift never leaves the
// all-zero state.
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	rng := openAISelectionRNG{state: seed}
	if rng.state == 0 {
		rng.state = 0x9e3779b97f4a7c15
	}
	return rng
}

// nextUint64 advances the xorshift64* state and returns the next 64-bit value.
func (r *openAISelectionRNG) nextUint64() uint64 {
	const multiplier uint64 = 2685821657736338717

	s := r.state
	s ^= s >> 12
	s ^= s << 25
	s ^= s >> 27
	r.state = s
	return s * multiplier
}

// nextFloat64 returns the next value in [0, 1), built from the top 53 bits of
// nextUint64.
func (r *openAISelectionRNG) nextFloat64() float64 {
	mantissa := r.nextUint64() >> 11
	return float64(mantissa) / (1 << 53)
}
// deriveOpenAISelectionSeed derives the RNG seed for a request.
//
// The seed is an FNV-1a hash over the trimmed, non-empty SessionHash,
// PreviousResponseID and RequestedModel (each followed by a zero byte) and the
// decimal GroupID when set. Requests that carry a session hash or a previous
// response ID therefore get a stable seed; all others are mixed with the
// current time. The result is never zero.
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
	sessionHash := strings.TrimSpace(req.SessionHash)
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	requestedModel := strings.TrimSpace(req.RequestedModel)

	hasher := fnv.New64a()
	for _, part := range [...]string{sessionHash, previousResponseID, requestedModel} {
		if part == "" {
			continue
		}
		_, _ = hasher.Write([]byte(part))
		_, _ = hasher.Write([]byte{0})
	}
	if req.GroupID != nil {
		_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
	}
	seed := hasher.Sum64()

	// Pure load-balancing requests have no session anchor: add time entropy
	// so they do not always land on the same account.
	if sessionHash == "" && previousResponseID == "" {
		seed ^= uint64(time.Now().UnixNano())
	}
	if seed == 0 {
		seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
	}
	return seed
}
// buildOpenAIWeightedSelectionOrder returns the candidates in the order in
// which slot acquisition should be attempted.
//
// The order is a weighted random permutation: candidates are drawn one at a
// time without replacement, each with probability proportional to its weight.
// The random source is seeded from req via deriveOpenAISelectionSeed. The
// input slice is not modified.
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	remaining := append([]openAIAccountCandidateScore(nil), candidates...)
	if len(remaining) <= 1 {
		return remaining
	}

	lowest := remaining[0].score
	for _, candidate := range remaining[1:] {
		if candidate.score < lowest {
			lowest = candidate.score
		}
	}

	weights := make([]float64, len(remaining))
	for i := range remaining {
		// Shift the top-K scores into the positive range so that a single
		// top-scoring account cannot monopolise selection.
		w := (remaining[i].score - lowest) + 1.0
		if math.IsNaN(w) || math.IsInf(w, 0) || w <= 0 {
			w = 1.0
		}
		weights[i] = w
	}

	order := make([]openAIAccountCandidateScore, 0, len(remaining))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	for len(remaining) > 0 {
		var total float64
		for _, w := range weights {
			total += w
		}

		pick := 0
		if total > 0 {
			// Walk the cumulative weights until the random threshold is covered.
			threshold := rng.nextFloat64() * total
			cumulative := 0.0
			for i, w := range weights {
				cumulative += w
				if threshold <= cumulative {
					pick = i
					break
				}
			}
		} else {
			pick = int(rng.nextUint64() % uint64(len(remaining)))
		}

		order = append(order, remaining[pick])

		// Remove the drawn candidate and its weight, keeping the relative order.
		copy(remaining[pick:], remaining[pick+1:])
		remaining = remaining[:len(remaining)-1]
		copy(weights[pick:], weights[pick+1:])
		weights = weights[:len(weights)-1]
	}
	return order
}
// selectByLoadBalance picks an account by scoring every eligible account and
// attempting slot acquisition in a weighted random order over the top-K.
//
// It returns the selection, the number of scored candidates, the top-K size
// used, the load skew (standard deviation of the candidates' load rates) and
// an error. An error is returned when listing accounts fails, when no account
// is eligible, or when a slot acquisition attempt itself fails. When every
// candidate is busy, the result carries a wait plan for the first candidate in
// the selection order instead of an acquired slot.
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
if err != nil {
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// Step 1: filter out excluded, unschedulable, non-OpenAI, model-incompatible
// and transport-incompatible accounts, and build the batch load request.
filtered := make([]*Account, 0, len(accounts))
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
for i := range accounts {
// Take the address of the slice element so no per-iteration copy is made.
account := &accounts[i]
if req.ExcludedIDs != nil {
if _, excluded := req.ExcludedIDs[account.ID]; excluded {
continue
}
}
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
continue
}
filtered = append(filtered, account)
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.Concurrency,
})
}
if len(filtered) == 0 {
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
// Step 2: fetch load information. A failed or unavailable lookup is
// tolerated: the map stays empty and zero-valued load info is used below.
loadMap := map[int64]*AccountLoadInfo{}
if s.service.concurrencyService != nil {
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
loadMap = batchLoad
}
}
// Step 3: collect the ranges needed to normalise the scoring factors
// (priority, waiting count, TTFT) and the moments for the load skew.
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
// maxWaiting starts at 1 so the queue factor never divides by zero.
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
if account.Priority < minPriority {
minPriority = account.Priority
}
if account.Priority > maxPriority {
maxPriority = account.Priority
}
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
// Only positive TTFT samples contribute to the min/max range.
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
hasTTFTSample = true
} else {
if ttft < minTTFT {
minTTFT = ttft
}
if ttft > maxTTFT {
maxTTFT = ttft
}
}
}
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
// Step 4: score each candidate as a weighted sum of factors, each in [0, 1]
// with 1 being best.
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
// A lower Priority value yields a higher factor; all-equal priorities give 1.
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
// LoadRate is divided by 100 before clamping, i.e. it is treated as a percentage.
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
// Accounts without a TTFT sample, or a degenerate TTFT range, get the neutral 0.5.
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
// Step 5: keep the top-K candidates (K clamped to [1, len(candidates)]) and
// derive the weighted random attempt order.
topK := s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
// Step 6: take the first candidate whose slot can be acquired. An
// acquisition error aborts the whole selection.
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
if acquireErr != nil {
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
// Best effort: the binding error is deliberately ignored.
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
}
return &AccountSelectionResult{
Account: candidate.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}
}
// Step 7: every candidate is busy; return a wait plan for the first one in
// the selection order using the fallback wait settings.
cfg := s.service.schedulingConfig()
candidate := selectionOrder[0]
return &AccountSelectionResult{
Account: candidate.account,
WaitPlan: &AccountWaitPlan{
AccountID: candidate.account.ID,
MaxConcurrency: candidate.account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, len(candidates), topK, loadSkew, nil
}
// isAccountTransportCompatible reports whether account can serve a request
// that needs requiredTransport.
//
// HTTP inbound traffic can fall back to the HTTP path, so the "any" and
// HTTP/SSE requirements never filter accounts at selection time. For every
// other requirement the transport resolved for the account must match
// exactly; a nil scheduler, service or account is incompatible.
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
	needsStrictMatch := requiredTransport != OpenAIUpstreamTransportAny &&
		requiredTransport != OpenAIUpstreamTransportHTTPSSE
	if !needsStrictMatch {
		return true
	}
	if s == nil || s.service == nil || account == nil {
		return false
	}
	resolved := s.service.getOpenAIWSProtocolResolver().Resolve(account)
	return resolved.Transport == requiredTransport
}
// ReportResult forwards a request outcome to the per-account runtime stats.
// It is a no-op when the scheduler or its stats registry is nil.
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
	if s != nil && s.stats != nil {
		s.stats.report(accountID, success, firstTokenMs)
	}
}
// ReportSwitch counts one account switch in the scheduler metrics.
// It is a no-op on a nil scheduler.
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
	if s != nil {
		s.metrics.recordSwitch()
	}
}
// SnapshotMetrics reads the scheduler counters and derives the averages and
// ratios from them. The derived fields stay zero until at least one selection
// has been recorded. A nil scheduler yields the zero snapshot.
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	var snapshot OpenAIAccountSchedulerMetricsSnapshot
	if s == nil {
		return snapshot
	}

	m := &s.metrics
	snapshot.SelectTotal = m.selectTotal.Load()
	snapshot.StickyPreviousHitTotal = m.stickyPreviousHitTotal.Load()
	snapshot.StickySessionHitTotal = m.stickySessionHitTotal.Load()
	snapshot.AccountSwitchTotal = m.accountSwitchTotal.Load()
	snapshot.SchedulerLatencyMsTotal = m.latencyMsTotal.Load()
	skewMilliTotal := m.loadSkewMilliTotal.Load()
	snapshot.LoadBalanceSelectTotal = m.loadBalanceSelectTotal.Load()
	snapshot.RuntimeStatsAccountCount = s.stats.size()

	if snapshot.SelectTotal <= 0 {
		return snapshot
	}

	selects := float64(snapshot.SelectTotal)
	stickyHits := snapshot.StickyPreviousHitTotal + snapshot.StickySessionHitTotal
	snapshot.SchedulerLatencyMsAvg = float64(snapshot.SchedulerLatencyMsTotal) / selects
	snapshot.StickyHitRatio = float64(stickyHits) / selects
	snapshot.AccountSwitchRate = float64(snapshot.AccountSwitchTotal) / selects
	// The skew counter is kept in thousandths; scale it back before averaging.
	snapshot.LoadSkewAvg = float64(skewMilliTotal) / 1000 / selects
	return snapshot
}
// getOpenAIAccountScheduler returns the service's scheduler, initialising it
// on first use through openaiSchedulerOnce. A runtime stats registry and a
// scheduler that are already set on the service are kept as they are. A nil
// service yields nil.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}

	initOnce := func() {
		stats := s.openaiAccountStats
		if stats == nil {
			stats = newOpenAIAccountRuntimeStats()
			s.openaiAccountStats = stats
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, stats)
		}
	}
	s.openaiSchedulerOnce.Do(initOnce)

	return s.openaiScheduler
}
// SelectAccountWithScheduler selects an account through the service's
// scheduler.
//
// When no scheduler is available it falls back to
// SelectAccountWithLoadAwareness and reports the load-balance layer. Otherwise
// it looks up the sticky account for sessionHash up front (lookup failures are
// ignored) and hands the assembled request to the scheduler.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		fallback := OpenAIAccountScheduleDecision{Layer: openAIAccountScheduleLayerLoadBalance}
		return selection, fallback, err
	}

	req := OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	}
	if sessionHash != "" && s.cache != nil {
		if boundID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && boundID > 0 {
			req.StickyAccountID = boundID
		}
	}
	return scheduler.Select(ctx, req)
}
// ReportOpenAIAccountScheduleResult forwards a request outcome for accountID
// to the scheduler; it does nothing when no scheduler is available.
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportResult(accountID, success, firstTokenMs)
	}
}
// RecordOpenAIAccountSwitch reports one account switch to the scheduler; it
// does nothing when no scheduler is available.
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportSwitch()
	}
}
// SnapshotOpenAIAccountSchedulerMetrics returns the scheduler's current
// metrics, or the zero snapshot when no scheduler is available.
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		return scheduler.SnapshotMetrics()
	}
	return OpenAIAccountSchedulerMetricsSnapshot{}
}
// openAIWSSessionStickyTTL returns the sticky-session TTL: the configured
// StickySessionTTLSeconds when positive, otherwise openaiStickySessionTTL.
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
	if s == nil || s.cfg == nil {
		return openaiStickySessionTTL
	}
	if seconds := s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds; seconds > 0 {
		return time.Duration(seconds) * time.Second
	}
	return openaiStickySessionTTL
}
// openAIWSLBTopK returns the size of the load-balancing top-K pool: the
// configured LBTopK when positive, otherwise 7.
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
	const defaultTopK = 7

	if s == nil || s.cfg == nil {
		return defaultTopK
	}
	if configured := s.cfg.Gateway.OpenAIWS.LBTopK; configured > 0 {
		return configured
	}
	return defaultTopK
}
// openAIWSSchedulerWeights returns the scoring weights used by the
// load-balancing layer. The configured weights are copied verbatim whenever a
// configuration is present; built-in defaults apply only when the service or
// its configuration is nil.
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
	if s == nil || s.cfg == nil {
		return GatewayOpenAIWSSchedulerScoreWeightsView{
			Priority:  1.0,
			Load:      1.0,
			Queue:     0.7,
			ErrorRate: 0.8,
			TTFT:      0.5,
		}
	}

	configured := s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights
	return GatewayOpenAIWSSchedulerScoreWeightsView{
		Priority:  configured.Priority,
		Load:      configured.Load,
		Queue:     configured.Queue,
		ErrorRate: configured.ErrorRate,
		TTFT:      configured.TTFT,
	}
}
// GatewayOpenAIWSSchedulerScoreWeightsView holds the weights that
// selectByLoadBalance applies to the individual scoring factors. Each weight
// multiplies a factor in [0, 1]; the weighted products are summed into the
// candidate score.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
// Priority weights the account-priority factor.
Priority float64
// Load weights the load-rate factor.
Load float64
// Queue weights the waiting-queue factor.
Queue float64
// ErrorRate weights the error-rate factor.
ErrorRate float64
// TTFT weights the time-to-first-token factor.
TTFT float64
}
// clamp01 limits value to the closed interval [0, 1]. NaN is returned
// unchanged because it compares false against both bounds.
func clamp01(value float64) float64 {
	if value < 0 {
		return 0
	}
	if value > 1 {
		return 1
	}
	return value
}
// calcLoadSkewByMoments returns the population standard deviation of count
// values given their sum and sum of squares, using Var = E[x^2] - (E[x])^2.
// It returns 0 for count <= 1 and whenever rounding drives the variance
// negative.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	mean := sum / float64(count)
	variance := sumSquares/float64(count) - mean*mean
	if variance < 0 {
		return 0
	}
	return math.Sqrt(variance)
}