* feat: Add validation and account management functionality - Add validation for clientID and clientSecret in refreshOIDCToken function - Add weight field for load balancing priority in Account struct - Implement a weighted round-robin strategy that gives each account a selection probability proportional to its weight. - Add batch account management functionality including enabling, disabling, refreshing, and retrieving account details. - Update Kiro API version and adjust user agent strings to reflect new version numbers. - Update Kiro version and modify user agent strings and header settings. - Refactor model mapping to an ordered list for precise key matching. - Add account bulk actions and filtering toolbar to index.html * feat: Add logic to skip accounts with exhausted usage limits - Add logic to skip accounts with exhausted usage limits when selecting the next account.
219 lines
4.8 KiB
Go
219 lines
4.8 KiB
Go
// Package pool 账号池管理
|
||
// 实现轮询负载均衡、错误冷却、Token 刷新
|
||
package pool
|
||
|
||
import (
|
||
"kiro-api-proxy/config"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
)
|
||
|
||
// AccountPool 账号池
|
||
type AccountPool struct {
|
||
mu sync.RWMutex
|
||
accounts []config.Account
|
||
currentIndex uint64
|
||
cooldowns map[string]time.Time // 账号冷却时间
|
||
errorCounts map[string]int // 连续错误计数
|
||
}
|
||
|
||
var (
|
||
pool *AccountPool
|
||
poolOnce sync.Once
|
||
)
|
||
|
||
// GetPool 获取全局账号池单例
|
||
func GetPool() *AccountPool {
|
||
poolOnce.Do(func() {
|
||
pool = &AccountPool{
|
||
cooldowns: make(map[string]time.Time),
|
||
errorCounts: make(map[string]int),
|
||
}
|
||
pool.Reload()
|
||
})
|
||
return pool
|
||
}
|
||
|
||
// Reload 从配置重新加载账号
|
||
// 构建加权列表:weight<=1 出现 1 次,weight>=2 出现 weight 次
|
||
func (p *AccountPool) Reload() {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
enabled := config.GetEnabledAccounts()
|
||
var weighted []config.Account
|
||
for _, a := range enabled {
|
||
w := a.Weight
|
||
if w < 1 {
|
||
w = 1
|
||
}
|
||
for j := 0; j < w; j++ {
|
||
weighted = append(weighted, a)
|
||
}
|
||
}
|
||
p.accounts = weighted
|
||
}
|
||
|
||
// GetNext 获取下一个可用账号(加权轮询)
|
||
func (p *AccountPool) GetNext() *config.Account {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
|
||
if len(p.accounts) == 0 {
|
||
return nil
|
||
}
|
||
|
||
now := time.Now()
|
||
n := len(p.accounts)
|
||
seen := make(map[string]bool)
|
||
|
||
// 加权轮询查找可用账号
|
||
for i := 0; i < n; i++ {
|
||
idx := atomic.AddUint64(&p.currentIndex, 1) % uint64(n)
|
||
acc := &p.accounts[idx]
|
||
|
||
if seen[acc.ID] {
|
||
continue
|
||
}
|
||
|
||
// 跳过冷却中的账号
|
||
if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) {
|
||
seen[acc.ID] = true
|
||
continue
|
||
}
|
||
|
||
// 跳过即将过期的 Token
|
||
if acc.ExpiresAt > 0 && time.Now().Unix() > acc.ExpiresAt-300 {
|
||
seen[acc.ID] = true
|
||
continue
|
||
}
|
||
|
||
// 跳过额度已用尽的账号(适用于所有订阅类型)
|
||
if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit {
|
||
seen[acc.ID] = true
|
||
continue
|
||
}
|
||
|
||
return acc
|
||
}
|
||
|
||
// 无可用账号,返回冷却时间最短的(排除额度用尽的)
|
||
var best *config.Account
|
||
var earliest time.Time
|
||
for i := range p.accounts {
|
||
acc := &p.accounts[i]
|
||
// 额度用尽的账号不作为 fallback
|
||
if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit {
|
||
continue
|
||
}
|
||
if cooldown, ok := p.cooldowns[acc.ID]; ok {
|
||
if best == nil || cooldown.Before(earliest) {
|
||
best = acc
|
||
earliest = cooldown
|
||
}
|
||
} else {
|
||
return acc
|
||
}
|
||
}
|
||
return best
|
||
}
|
||
|
||
// GetByID 根据 ID 获取账号
|
||
func (p *AccountPool) GetByID(id string) *config.Account {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
for i := range p.accounts {
|
||
if p.accounts[i].ID == id {
|
||
return &p.accounts[i]
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// RecordSuccess 记录请求成功,清除冷却
|
||
func (p *AccountPool) RecordSuccess(id string) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
delete(p.cooldowns, id)
|
||
p.errorCounts[id] = 0
|
||
}
|
||
|
||
// RecordError 记录请求错误,设置冷却
|
||
func (p *AccountPool) RecordError(id string, isQuotaError bool) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
p.errorCounts[id]++
|
||
|
||
if isQuotaError {
|
||
// 配额错误,冷却 1 小时
|
||
p.cooldowns[id] = time.Now().Add(time.Hour)
|
||
} else if p.errorCounts[id] >= 3 {
|
||
// 连续 3 次错误,冷却 1 分钟
|
||
p.cooldowns[id] = time.Now().Add(time.Minute)
|
||
}
|
||
}
|
||
|
||
// UpdateToken 更新账号 Token
|
||
func (p *AccountPool) UpdateToken(id, accessToken, refreshToken string, expiresAt int64) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
for i := range p.accounts {
|
||
if p.accounts[i].ID == id {
|
||
p.accounts[i].AccessToken = accessToken
|
||
if refreshToken != "" {
|
||
p.accounts[i].RefreshToken = refreshToken
|
||
}
|
||
p.accounts[i].ExpiresAt = expiresAt
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// Count 返回账号总数
|
||
func (p *AccountPool) Count() int {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
return len(p.accounts)
|
||
}
|
||
|
||
// AvailableCount 返回可用账号数
|
||
func (p *AccountPool) AvailableCount() int {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
now := time.Now()
|
||
count := 0
|
||
for _, acc := range p.accounts {
|
||
if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) {
|
||
continue
|
||
}
|
||
count++
|
||
}
|
||
return count
|
||
}
|
||
|
||
// UpdateStats 更新账号统计
|
||
func (p *AccountPool) UpdateStats(id string, tokens int, credits float64) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
for i := range p.accounts {
|
||
if p.accounts[i].ID == id {
|
||
p.accounts[i].RequestCount++
|
||
p.accounts[i].TotalTokens += tokens
|
||
p.accounts[i].TotalCredits += credits
|
||
p.accounts[i].LastUsed = time.Now().Unix()
|
||
go config.UpdateAccountStats(id, p.accounts[i].RequestCount, p.accounts[i].ErrorCount, p.accounts[i].TotalTokens, p.accounts[i].TotalCredits, p.accounts[i].LastUsed)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetAllAccounts 获取所有账号副本
|
||
func (p *AccountPool) GetAllAccounts() []config.Account {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
result := make([]config.Account, len(p.accounts))
|
||
copy(result, p.accounts)
|
||
return result
|
||
}
|