Files
kirogo/pool/account.go
hkxiaoyao ad7aabd554 feat: Add validation and account management functionality (#21)
* feat: Add validation and account management functionality

- Add validation for clientID and clientSecret in refreshOIDCToken function
- Add weight field for load balancing priority in Account struct
- Implement a weighted round-robin strategy that distributes selection probability according to account weight.
- Add batch account management functionality including enabling, disabling, refreshing, and retrieving account details.
- Update Kiro API version and adjust user agent strings to reflect new version numbers.
- Update Kiro version and modify user agent strings and header settings.
- Refactor model mapping to an ordered list for precise key matching.
- Add account bulk actions and filtering toolbar to index.html

* feat: Add logic to skip accounts with exhausted usage limits

- Add logic to skip accounts with exhausted usage limits when selecting the next account.
2026-02-23 21:47:17 +08:00

219 lines
4.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package pool manages the account pool.
// It implements weighted round-robin load balancing, error cooldowns, and token refresh.
package pool
import (
"kiro-api-proxy/config"
"sync"
"sync/atomic"
"time"
)
// AccountPool is the in-memory pool of enabled accounts used for weighted
// round-robin selection with per-account error cooldowns.
//
// Obtain it through GetPool: the zero value has nil maps, and RecordSuccess /
// RecordError would panic writing to them. All fields except currentIndex are
// guarded by mu; currentIndex is advanced with atomic.AddUint64 in GetNext.
type AccountPool struct {
	// currentIndex is the shared round-robin cursor. It MUST remain the first
	// field: 64-bit atomic operations require 64-bit alignment, and on 386,
	// ARM and 32-bit MIPS that is only guaranteed for the first word of an
	// allocated struct. Placed after mu and accounts (as it used to be) it is
	// not 8-byte aligned on those platforms and atomic.AddUint64 panics.
	currentIndex uint64

	mu          sync.RWMutex
	accounts    []config.Account     // rotation built by Reload: one entry per weight unit (at least one)
	cooldowns   map[string]time.Time // account ID -> time until which the account is cooling down
	errorCounts map[string]int       // account ID -> consecutive error count
}
// pool is the process-wide AccountPool singleton; it is created lazily by GetPool.
var pool *AccountPool

// poolOnce ensures pool is constructed exactly once, even under concurrent GetPool calls.
var poolOnce sync.Once
// GetPool returns the process-wide AccountPool, creating it on first use.
//
// The first call allocates the cooldown and error-count maps, publishes the
// instance, and then loads the enabled accounts from config via Reload.
// Every later call returns that same instance.
func GetPool() *AccountPool {
	poolOnce.Do(func() {
		p := new(AccountPool)
		p.cooldowns = map[string]time.Time{}
		p.errorCounts = map[string]int{}
		pool = p
		p.Reload()
	})
	return pool
}
// Reload rebuilds the rotation from the accounts config reports as enabled.
//
// Weighting is implemented by repetition. An account with Weight <= 1 occupies
// one slot, and an account with Weight >= 2 occupies Weight slots. Each slot
// holds its own copy of the account struct. Cooldowns and error counts are
// left untouched.
func (p *AccountPool) Reload() {
	p.mu.Lock()
	defer p.mu.Unlock()

	var slots []config.Account
	for _, acc := range config.GetEnabledAccounts() {
		copies := acc.Weight
		if copies < 1 {
			copies = 1 // non-positive weights still get one slot
		}
		for ; copies > 0; copies-- {
			slots = append(slots, acc)
		}
	}
	p.accounts = slots
}
// GetNext returns the next usable account using weighted round-robin.
//
// Each probe advances the shared cursor p.currentIndex atomically, so
// concurrent callers holding only the read lock still walk different slots.
// A slot is rejected when its account:
//   - is cooling down (its cooldown deadline is still in the future),
//   - has a token expiring within 300 seconds (ExpiresAt > 0, in Unix seconds), or
//   - has a positive UsageLimit that UsageCurrent has reached.
//
// Once one copy of an account is rejected, its other weighted copies are
// skipped without re-checking.
//
// If every slot is rejected, it falls back to the first account in slice
// order that has no cooldown entry (e.g. one skipped only for token expiry).
// Failing that, it returns the account whose cooldown deadline is earliest.
// Accounts with exhausted usage are never returned. The result is nil when the
// pool is empty or every account has exhausted its usage.
//
// The returned pointer aliases the pool's internal slice and is read by the
// caller after the lock has been released.
func (p *AccountPool) GetNext() *config.Account {
	p.mu.RLock()
	defer p.mu.RUnlock()

	n := len(p.accounts)
	if n == 0 {
		return nil
	}

	// Capture the clock once so the cooldown and token-expiry checks agree.
	now := time.Now()
	nowUnix := now.Unix()

	// rejected is allocated lazily: reading a nil map is safe, and in the
	// common case the first probed slot is usable and nothing is allocated.
	var rejected map[string]bool

	for i := 0; i < n; i++ {
		idx := atomic.AddUint64(&p.currentIndex, 1) % uint64(n)
		acc := &p.accounts[idx]
		if rejected[acc.ID] {
			continue
		}

		until, hasCooldown := p.cooldowns[acc.ID]
		coolingDown := hasCooldown && now.Before(until)
		tokenExpiring := acc.ExpiresAt > 0 && nowUnix > acc.ExpiresAt-300
		usageExhausted := acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit // applies to every subscription type
		if !coolingDown && !tokenExpiring && !usageExhausted {
			return acc
		}

		if rejected == nil {
			rejected = make(map[string]bool)
		}
		rejected[acc.ID] = true
	}

	// Fallback: nothing passed the checks above.
	var best *config.Account
	var earliest time.Time
	for i := range p.accounts {
		acc := &p.accounts[i]
		if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit {
			continue // never fall back to an account that has used up its quota
		}
		until, hasCooldown := p.cooldowns[acc.ID]
		if !hasCooldown {
			return acc
		}
		if best == nil || until.Before(earliest) {
			best = acc
			earliest = until
		}
	}
	return best
}
// GetByID returns a pointer to the first pool entry whose ID equals id, or
// nil if no entry matches. With weights, several entries can share an ID;
// only the first (in rotation order) is returned. The pointer aliases the
// pool's internal slice and is used after the read lock is released.
func (p *AccountPool) GetByID(id string) *config.Account {
	p.mu.RLock()
	defer p.mu.RUnlock()
	for idx := 0; idx < len(p.accounts); idx++ {
		if acc := &p.accounts[idx]; acc.ID == id {
			return acc
		}
	}
	return nil
}
// RecordSuccess notes a successful request for account id. It lifts any
// cooldown the account is under and resets its consecutive-error counter
// to zero.
func (p *AccountPool) RecordSuccess(id string) {
	p.mu.Lock()
	defer p.mu.Unlock()

	delete(p.cooldowns, id) // no-op when there is no cooldown
	p.errorCounts[id] = 0
}
// RecordError records a failed request for account id and may put the
// account into cooldown.
//
// The consecutive-error counter is always incremented. A quota error cools
// the account down for one hour. Any other error cools it down for one minute
// once the account has accumulated 3 or more consecutive errors.
//
// An existing cooldown is never shortened. A later, shorter penalty leaves the
// longer deadline in place. Example: a transient error on an account already
// serving a one-hour quota cooldown, which GetNext's fallback can still hand
// out. Previously such an error cut the quota cooldown down to one minute.
func (p *AccountPool) RecordError(id string, isQuotaError bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.errorCounts[id]++

	var penalty time.Duration
	switch {
	case isQuotaError:
		penalty = time.Hour
	case p.errorCounts[id] >= 3:
		penalty = time.Minute
	default:
		return // below the error threshold: no cooldown yet
	}

	until := time.Now().Add(penalty)
	if current, ok := p.cooldowns[id]; ok && current.After(until) {
		return // keep the longer cooldown already in place
	}
	p.cooldowns[id] = until
}
// UpdateToken stores newly issued credentials for account id.
//
// Reload stores one copy of an account per weight unit. Every copy with a
// matching ID is therefore updated; stopping at the first match would let
// GetNext keep handing out weighted copies with the old access token and
// expiry. An empty refreshToken leaves the stored refresh token unchanged.
// Unknown IDs are ignored.
func (p *AccountPool) UpdateToken(id, accessToken, refreshToken string, expiresAt int64) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for i := range p.accounts {
		acc := &p.accounts[i]
		if acc.ID != id {
			continue
		}
		acc.AccessToken = accessToken
		if refreshToken != "" {
			acc.RefreshToken = refreshToken
		}
		acc.ExpiresAt = expiresAt
	}
}
// Count returns the number of distinct accounts in the pool.
//
// p.accounts holds one entry per weight unit (see Reload), so its raw length
// overstates the account count whenever an account has Weight >= 2. Entries
// are therefore de-duplicated by ID.
func (p *AccountPool) Count() int {
	p.mu.RLock()
	defer p.mu.RUnlock()
	ids := make(map[string]struct{}, len(p.accounts))
	for i := range p.accounts {
		ids[p.accounts[i].ID] = struct{}{}
	}
	return len(ids)
}
// AvailableCount returns the number of distinct accounts that are currently
// usable. An account is not counted if it is still inside a cooldown, or if
// it has a positive UsageLimit that UsageCurrent has reached. GetNext never
// returns such accounts, even as a fallback.
//
// Weighted copies of the same account (same ID) are counted once.
func (p *AccountPool) AvailableCount() int {
	p.mu.RLock()
	defer p.mu.RUnlock()

	now := time.Now()
	available := make(map[string]struct{}, len(p.accounts))
	for i := range p.accounts {
		acc := &p.accounts[i] // index to avoid copying the struct
		if _, counted := available[acc.ID]; counted {
			continue
		}
		if until, ok := p.cooldowns[acc.ID]; ok && now.Before(until) {
			continue
		}
		if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit {
			continue
		}
		available[acc.ID] = struct{}{}
	}
	return len(available)
}
// UpdateStats records one completed request for account id. It increments
// RequestCount, adds tokens and credits to the running totals, and sets
// LastUsed to the current Unix time.
//
// Every weighted copy of the account is updated so that the copies do not
// diverge. The new values are then handed once to config.UpdateAccountStats
// in a background goroutine. The arguments are evaluated while the lock is
// held, so the goroutine sees a consistent snapshot. Persistence is
// fire-and-forget; its outcome is not observed here. Unknown IDs are ignored.
func (p *AccountPool) UpdateStats(id string, tokens int, credits float64) {
	p.mu.Lock()
	defer p.mu.Unlock()

	nowUnix := time.Now().Unix()
	var first *config.Account
	for i := range p.accounts {
		acc := &p.accounts[i]
		if acc.ID != id {
			continue
		}
		acc.RequestCount++
		acc.TotalTokens += tokens
		acc.TotalCredits += credits
		acc.LastUsed = nowUnix
		if first == nil {
			first = acc
		}
	}
	if first == nil {
		return
	}
	go config.UpdateAccountStats(id, first.RequestCount, first.ErrorCount, first.TotalTokens, first.TotalCredits, first.LastUsed)
}
// GetAllAccounts returns a copy of every distinct account in the pool, in
// rotation order of first appearance.
//
// p.accounts holds one entry per weight unit (see Reload). Entries are
// de-duplicated by ID so that an account with Weight >= 2 is not listed
// several times. The returned slice is never nil, so an empty pool encodes
// as an empty list rather than null.
func (p *AccountPool) GetAllAccounts() []config.Account {
	p.mu.RLock()
	defer p.mu.RUnlock()

	result := make([]config.Account, 0, len(p.accounts))
	seen := make(map[string]struct{}, len(p.accounts))
	for i := range p.accounts {
		id := p.accounts[i].ID
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		result = append(result, p.accounts[i])
	}
	return result
}