feat(sync): full code sync from release
This commit is contained in:
909
backend/internal/service/openai_account_scheduler.go
Normal file
909
backend/internal/service/openai_account_scheduler.go
Normal file
@@ -0,0 +1,909 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
|
||||
openAIAccountScheduleLayerSessionSticky = "session_hash"
|
||||
openAIAccountScheduleLayerLoadBalance = "load_balance"
|
||||
)
|
||||
|
||||
type OpenAIAccountScheduleRequest struct {
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
StickyAccountID int64
|
||||
PreviousResponseID string
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
ExcludedIDs map[int64]struct{}
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduleDecision struct {
|
||||
Layer string
|
||||
StickyPreviousHit bool
|
||||
StickySessionHit bool
|
||||
CandidateCount int
|
||||
TopK int
|
||||
LatencyMs int64
|
||||
LoadSkew float64
|
||||
SelectedAccountID int64
|
||||
SelectedAccountType string
|
||||
}
|
||||
|
||||
type OpenAIAccountSchedulerMetricsSnapshot struct {
|
||||
SelectTotal int64
|
||||
StickyPreviousHitTotal int64
|
||||
StickySessionHitTotal int64
|
||||
LoadBalanceSelectTotal int64
|
||||
AccountSwitchTotal int64
|
||||
SchedulerLatencyMsTotal int64
|
||||
SchedulerLatencyMsAvg float64
|
||||
StickyHitRatio float64
|
||||
AccountSwitchRate float64
|
||||
LoadSkewAvg float64
|
||||
RuntimeStatsAccountCount int
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduler interface {
|
||||
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
|
||||
ReportResult(accountID int64, success bool, firstTokenMs *int)
|
||||
ReportSwitch()
|
||||
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
|
||||
}
|
||||
|
||||
type openAIAccountSchedulerMetrics struct {
|
||||
selectTotal atomic.Int64
|
||||
stickyPreviousHitTotal atomic.Int64
|
||||
stickySessionHitTotal atomic.Int64
|
||||
loadBalanceSelectTotal atomic.Int64
|
||||
accountSwitchTotal atomic.Int64
|
||||
latencyMsTotal atomic.Int64
|
||||
loadSkewMilliTotal atomic.Int64
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.selectTotal.Add(1)
|
||||
m.latencyMsTotal.Add(decision.LatencyMs)
|
||||
m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000)))
|
||||
if decision.StickyPreviousHit {
|
||||
m.stickyPreviousHitTotal.Add(1)
|
||||
}
|
||||
if decision.StickySessionHit {
|
||||
m.stickySessionHitTotal.Add(1)
|
||||
}
|
||||
if decision.Layer == openAIAccountScheduleLayerLoadBalance {
|
||||
m.loadBalanceSelectTotal.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.accountSwitchTotal.Add(1)
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStats struct {
|
||||
accounts sync.Map
|
||||
accountCount atomic.Int64
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStat struct {
|
||||
errorRateEWMABits atomic.Uint64
|
||||
ttftEWMABits atomic.Uint64
|
||||
}
|
||||
|
||||
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
|
||||
return &openAIAccountRuntimeStats{}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
|
||||
if value, ok := s.accounts.Load(accountID); ok {
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat != nil {
|
||||
return stat
|
||||
}
|
||||
}
|
||||
|
||||
stat := &openAIAccountRuntimeStat{}
|
||||
stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
|
||||
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
|
||||
if !loaded {
|
||||
s.accountCount.Add(1)
|
||||
return stat
|
||||
}
|
||||
existing, _ := actual.(*openAIAccountRuntimeStat)
|
||||
if existing != nil {
|
||||
return existing
|
||||
}
|
||||
return stat
|
||||
}
|
||||
|
||||
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
|
||||
for {
|
||||
oldBits := target.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
newValue := alpha*sample + (1-alpha)*oldValue
|
||||
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
const alpha = 0.2
|
||||
stat := s.loadOrCreate(accountID)
|
||||
|
||||
errorSample := 1.0
|
||||
if success {
|
||||
errorSample = 0.0
|
||||
}
|
||||
updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
|
||||
|
||||
if firstTokenMs != nil && *firstTokenMs > 0 {
|
||||
ttft := float64(*firstTokenMs)
|
||||
ttftBits := math.Float64bits(ttft)
|
||||
for {
|
||||
oldBits := stat.ttftEWMABits.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
if math.IsNaN(oldValue) {
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
newValue := alpha*ttft + (1-alpha)*oldValue
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
value, ok := s.accounts.Load(accountID)
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
|
||||
ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
|
||||
if math.IsNaN(ttftValue) {
|
||||
return errorRate, 0, false
|
||||
}
|
||||
return errorRate, ttftValue, true
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) size() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return int(s.accountCount.Load())
|
||||
}
|
||||
|
||||
type defaultOpenAIAccountScheduler struct {
|
||||
service *OpenAIGatewayService
|
||||
metrics openAIAccountSchedulerMetrics
|
||||
stats *openAIAccountRuntimeStats
|
||||
}
|
||||
|
||||
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
|
||||
if stats == nil {
|
||||
stats = newOpenAIAccountRuntimeStats()
|
||||
}
|
||||
return &defaultOpenAIAccountScheduler{
|
||||
service: service,
|
||||
stats: stats,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) Select(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
decision.LatencyMs = time.Since(start).Milliseconds()
|
||||
s.metrics.recordSelect(decision)
|
||||
}()
|
||||
|
||||
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
|
||||
if previousResponseID != "" {
|
||||
selection, err := s.service.SelectAccountByPreviousResponseID(
|
||||
ctx,
|
||||
req.GroupID,
|
||||
previousResponseID,
|
||||
req.RequestedModel,
|
||||
req.ExcludedIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
|
||||
selection = nil
|
||||
}
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.Layer = openAIAccountScheduleLayerPreviousResponse
|
||||
decision.StickyPreviousHit = true
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
|
||||
}
|
||||
return selection, decision, nil
|
||||
}
|
||||
}
|
||||
|
||||
selection, err := s.selectBySessionHash(ctx, req)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.Layer = openAIAccountScheduleLayerSessionSticky
|
||||
decision.StickySessionHit = true
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
return selection, decision, nil
|
||||
}
|
||||
|
||||
selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
|
||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||
decision.CandidateCount = candidateCount
|
||||
decision.TopK = topK
|
||||
decision.LoadSkew = loadSkew
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
}
|
||||
return selection, decision, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, error) {
|
||||
sessionHash := strings.TrimSpace(req.SessionHash)
|
||||
if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
accountID := req.StickyAccountID
|
||||
if accountID <= 0 {
|
||||
var err error
|
||||
accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if req.ExcludedIDs != nil {
|
||||
if _, excluded := req.ExcludedIDs[accountID]; excluded {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
account, err := s.service.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil || account == nil {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
if s.service.concurrencyService != nil {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type openAIAccountCandidateScore struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
score float64
|
||||
errorRate float64
|
||||
ttft float64
|
||||
hasTTFT bool
|
||||
}
|
||||
|
||||
type openAIAccountCandidateHeap []openAIAccountCandidateScore
|
||||
|
||||
func (h openAIAccountCandidateHeap) Len() int {
|
||||
return len(h)
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
|
||||
// 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。
|
||||
return isOpenAIAccountCandidateBetter(h[j], h[i])
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Swap(i, j int) {
|
||||
h[i], h[j] = h[j], h[i]
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Push(x any) {
|
||||
candidate, ok := x.(openAIAccountCandidateScore)
|
||||
if !ok {
|
||||
panic("openAIAccountCandidateHeap: invalid element type")
|
||||
}
|
||||
*h = append(*h, candidate)
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
last := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
|
||||
if left.score != right.score {
|
||||
return left.score > right.score
|
||||
}
|
||||
if left.account.Priority != right.account.Priority {
|
||||
return left.account.Priority < right.account.Priority
|
||||
}
|
||||
if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
|
||||
return left.loadInfo.LoadRate < right.loadInfo.LoadRate
|
||||
}
|
||||
if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
|
||||
return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
|
||||
}
|
||||
return left.account.ID < right.account.ID
|
||||
}
|
||||
|
||||
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
if topK >= len(candidates) {
|
||||
ranked := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
best := make(openAIAccountCandidateHeap, 0, topK)
|
||||
for _, candidate := range candidates {
|
||||
if len(best) < topK {
|
||||
heap.Push(&best, candidate)
|
||||
continue
|
||||
}
|
||||
if isOpenAIAccountCandidateBetter(candidate, best[0]) {
|
||||
best[0] = candidate
|
||||
heap.Fix(&best, 0)
|
||||
}
|
||||
}
|
||||
|
||||
ranked := make([]openAIAccountCandidateScore, len(best))
|
||||
copy(ranked, best)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
type openAISelectionRNG struct {
|
||||
state uint64
|
||||
}
|
||||
|
||||
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
|
||||
if seed == 0 {
|
||||
seed = 0x9e3779b97f4a7c15
|
||||
}
|
||||
return openAISelectionRNG{state: seed}
|
||||
}
|
||||
|
||||
func (r *openAISelectionRNG) nextUint64() uint64 {
|
||||
// xorshift64*
|
||||
x := r.state
|
||||
x ^= x >> 12
|
||||
x ^= x << 25
|
||||
x ^= x >> 27
|
||||
r.state = x
|
||||
return x * 2685821657736338717
|
||||
}
|
||||
|
||||
func (r *openAISelectionRNG) nextFloat64() float64 {
|
||||
// [0,1)
|
||||
return float64(r.nextUint64()>>11) / (1 << 53)
|
||||
}
|
||||
|
||||
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
|
||||
hasher := fnv.New64a()
|
||||
writeValue := func(value string) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
_, _ = hasher.Write([]byte(trimmed))
|
||||
_, _ = hasher.Write([]byte{0})
|
||||
}
|
||||
|
||||
writeValue(req.SessionHash)
|
||||
writeValue(req.PreviousResponseID)
|
||||
writeValue(req.RequestedModel)
|
||||
if req.GroupID != nil {
|
||||
_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
|
||||
}
|
||||
|
||||
seed := hasher.Sum64()
|
||||
// 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。
|
||||
if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
|
||||
seed ^= uint64(time.Now().UnixNano())
|
||||
}
|
||||
if seed == 0 {
|
||||
seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
|
||||
}
|
||||
return seed
|
||||
}
|
||||
|
||||
func buildOpenAIWeightedSelectionOrder(
|
||||
candidates []openAIAccountCandidateScore,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) []openAIAccountCandidateScore {
|
||||
if len(candidates) <= 1 {
|
||||
return append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
}
|
||||
|
||||
pool := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
weights := make([]float64, len(pool))
|
||||
minScore := pool[0].score
|
||||
for i := 1; i < len(pool); i++ {
|
||||
if pool[i].score < minScore {
|
||||
minScore = pool[i].score
|
||||
}
|
||||
}
|
||||
for i := range pool {
|
||||
// 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。
|
||||
weight := (pool[i].score - minScore) + 1.0
|
||||
if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
|
||||
weight = 1.0
|
||||
}
|
||||
weights[i] = weight
|
||||
}
|
||||
|
||||
order := make([]openAIAccountCandidateScore, 0, len(pool))
|
||||
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
|
||||
for len(pool) > 0 {
|
||||
total := 0.0
|
||||
for _, w := range weights {
|
||||
total += w
|
||||
}
|
||||
|
||||
selectedIdx := 0
|
||||
if total > 0 {
|
||||
r := rng.nextFloat64() * total
|
||||
acc := 0.0
|
||||
for i, w := range weights {
|
||||
acc += w
|
||||
if r <= acc {
|
||||
selectedIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
|
||||
}
|
||||
|
||||
order = append(order, pool[selectedIdx])
|
||||
pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
|
||||
weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, int, int, float64, error) {
|
||||
accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
filtered := make([]*Account, 0, len(accounts))
|
||||
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
if req.ExcludedIDs != nil {
|
||||
if _, excluded := req.ExcludedIDs[account.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, account)
|
||||
loadReq = append(loadReq, AccountWithConcurrency{
|
||||
ID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
loadMap := map[int64]*AccountLoadInfo{}
|
||||
if s.service.concurrencyService != nil {
|
||||
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
|
||||
loadMap = batchLoad
|
||||
}
|
||||
}
|
||||
|
||||
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
if account.Priority < minPriority {
|
||||
minPriority = account.Priority
|
||||
}
|
||||
if account.Priority > maxPriority {
|
||||
maxPriority = account.Priority
|
||||
}
|
||||
if loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = loadInfo.WaitingCount
|
||||
}
|
||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
||||
if hasTTFT && ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = ttft, ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if ttft < minTTFT {
|
||||
minTTFT = ttft
|
||||
}
|
||||
if ttft > maxTTFT {
|
||||
maxTTFT = ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
candidates = append(candidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
}
|
||||
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
|
||||
topK := s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
|
||||
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
|
||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||
return true
|
||||
}
|
||||
if s == nil || s.service == nil || account == nil {
|
||||
return false
|
||||
}
|
||||
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.report(accountID, success, firstTokenMs)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.metrics.recordSwitch()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
if s == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
|
||||
selectTotal := s.metrics.selectTotal.Load()
|
||||
prevHit := s.metrics.stickyPreviousHitTotal.Load()
|
||||
sessionHit := s.metrics.stickySessionHitTotal.Load()
|
||||
switchTotal := s.metrics.accountSwitchTotal.Load()
|
||||
latencyTotal := s.metrics.latencyMsTotal.Load()
|
||||
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
|
||||
|
||||
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
|
||||
SelectTotal: selectTotal,
|
||||
StickyPreviousHitTotal: prevHit,
|
||||
StickySessionHitTotal: sessionHit,
|
||||
LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
|
||||
AccountSwitchTotal: switchTotal,
|
||||
SchedulerLatencyMsTotal: latencyTotal,
|
||||
RuntimeStatsAccountCount: s.stats.size(),
|
||||
}
|
||||
if selectTotal > 0 {
|
||||
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
|
||||
snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
|
||||
snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
|
||||
snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.openaiSchedulerOnce.Do(func() {
|
||||
if s.openaiAccountStats == nil {
|
||||
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
|
||||
}
|
||||
if s.openaiScheduler == nil {
|
||||
s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
|
||||
}
|
||||
})
|
||||
return s.openaiScheduler
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
previousResponseID string,
|
||||
sessionHash string,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||
return selection, decision, err
|
||||
}
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
|
||||
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
|
||||
GroupID: groupID,
|
||||
SessionHash: sessionHash,
|
||||
StickyAccountID: stickyAccountID,
|
||||
PreviousResponseID: previousResponseID,
|
||||
RequestedModel: requestedModel,
|
||||
RequiredTransport: requiredTransport,
|
||||
ExcludedIDs: excludedIDs,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
scheduler.ReportResult(accountID, success, firstTokenMs)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
scheduler.ReportSwitch()
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
return scheduler.SnapshotMetrics()
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
|
||||
return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
|
||||
}
|
||||
return openaiStickySessionTTL
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
|
||||
return s.cfg.Gateway.OpenAIWS.LBTopK
|
||||
}
|
||||
return 7
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
|
||||
if s != nil && s.cfg != nil {
|
||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||
Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
|
||||
Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
|
||||
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
|
||||
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
|
||||
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
|
||||
}
|
||||
}
|
||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||
Priority: 1.0,
|
||||
Load: 1.0,
|
||||
Queue: 0.7,
|
||||
ErrorRate: 0.8,
|
||||
TTFT: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
|
||||
Priority float64
|
||||
Load float64
|
||||
Queue float64
|
||||
ErrorRate float64
|
||||
TTFT float64
|
||||
}
|
||||
|
||||
func clamp01(value float64) float64 {
|
||||
switch {
|
||||
case value < 0:
|
||||
return 0
|
||||
case value > 1:
|
||||
return 1
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
|
||||
if count <= 1 {
|
||||
return 0
|
||||
}
|
||||
mean := sum / float64(count)
|
||||
variance := sumSquares/float64(count) - mean*mean
|
||||
if variance < 0 {
|
||||
variance = 0
|
||||
}
|
||||
return math.Sqrt(variance)
|
||||
}
|
||||
Reference in New Issue
Block a user