feat: add versioning, account export, and dynamic models list

This commit is contained in:
Quorinex
2026-02-08 01:31:27 +08:00
parent 9aad3dec7e
commit 3e7cca04ba
6 changed files with 587 additions and 54 deletions

View File

@@ -29,6 +29,10 @@ type Handler struct {
startTime int64
stopRefresh chan struct{}
stopStatsSaver chan struct{}
// 模型缓存
cachedModels []ModelInfo
modelsCacheMu sync.RWMutex
modelsCacheTime int64
}
func NewHandler() *Handler {
@@ -58,11 +62,13 @@ func (h *Handler) backgroundRefresh() {
// 启动时延迟 10 秒后执行一次
time.Sleep(10 * time.Second)
h.refreshModelsCache()
h.refreshAllAccounts()
for {
select {
case <-ticker.C:
h.refreshModelsCache()
h.refreshAllAccounts()
case <-h.stopRefresh:
return
@@ -211,19 +217,44 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
// handleModels 模型列表
func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
models := []map[string]interface{}{
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4.5-thinking", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4-thinking", "object": "model", "owned_by": "anthropic"},
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-haiku-4.5-thinking", "object": "model", "owned_by": "anthropic"},
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-opus-4.5-thinking", "object": "model", "owned_by": "anthropic"},
{"id": "auto", "object": "model", "owned_by": "kiro-api"},
{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
// 尝试用缓存的真实模型列表
h.modelsCacheMu.RLock()
cached := h.cachedModels
h.modelsCacheMu.RUnlock()
thinkingSuffix := config.GetThinkingConfig().Suffix
var models []map[string]interface{}
if len(cached) > 0 {
for _, m := range cached {
models = append(models, map[string]interface{}{
"id": m.ModelId, "object": "model", "owned_by": "anthropic",
})
// 自动生成 thinking 变体
models = append(models, map[string]interface{}{
"id": m.ModelId + thinkingSuffix, "object": "model", "owned_by": "anthropic",
})
}
} else {
// fallback 静态列表
models = []map[string]interface{}{
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
{"id": "claude-sonnet-4" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-haiku-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
{"id": "claude-opus-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
}
}
// 添加别名模型
models = append(models,
map[string]interface{}{"id": "auto", "object": "model", "owned_by": "kiro-proxy"},
map[string]interface{}{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
map[string]interface{}{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]interface{}{
"object": "list",
@@ -231,6 +262,33 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
})
}
// refreshModelsCache 从 Kiro API 拉取模型列表并缓存
func (h *Handler) refreshModelsCache() {
account := h.pool.GetNext()
if account == nil {
return
}
// 确保 token 有效
if err := h.ensureValidToken(account); err != nil {
return
}
models, err := ListAvailableModels(account)
if err != nil {
fmt.Printf("[ModelsCache] Failed to refresh: %v\n", err)
return
}
if len(models) > 0 {
h.modelsCacheMu.Lock()
h.cachedModels = models
h.modelsCacheTime = time.Now().Unix()
h.modelsCacheMu.Unlock()
fmt.Printf("[ModelsCache] Cached %d models\n", len(models))
}
}
// handleCountTokens Token 计数Claude Code 会调用)
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
@@ -1282,6 +1340,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
h.apiGetEndpointConfig(w, r)
case path == "/endpoint" && r.Method == "POST":
h.apiUpdateEndpointConfig(w, r)
case path == "/version" && r.Method == "GET":
h.apiGetVersion(w, r)
case path == "/export" && r.Method == "POST":
h.apiExportAccounts(w, r)
default:
w.WriteHeader(404)
json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
@@ -1707,31 +1769,47 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
req.AuthMethod = "social"
}
}
// 如果没有 accessToken尝试刷新获取
accessToken := req.AccessToken
var expiresAt int64
if accessToken == "" {
tempAccount := &config.Account{
RefreshToken: req.RefreshToken,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
AuthMethod: req.AuthMethod,
Region: req.Region,
// 标准化 authMethod
switch strings.ToLower(req.AuthMethod) {
case "idc", "builderid", "enterprise":
req.AuthMethod = "idc"
case "social", "google", "github":
req.AuthMethod = "social"
default:
if req.ClientID != "" && req.ClientSecret != "" {
req.AuthMethod = "idc"
} else {
req.AuthMethod = "social"
}
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount)
if err != nil {
}
// 始终尝试用 refreshToken 刷新获取新的 accessToken
var accessToken string
var expiresAt int64
tempAccount := &config.Account{
RefreshToken: req.RefreshToken,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
AuthMethod: req.AuthMethod,
Region: req.Region,
}
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount)
if err != nil {
// 刷新失败,如果有传入的 accessToken 则尝试使用
if req.AccessToken != "" {
accessToken = req.AccessToken
expiresAt = time.Now().Unix() + 300 // 可能已过期,设短一点
} else {
w.WriteHeader(400)
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
return
}
} else {
accessToken = newAccessToken
if newRefreshToken != "" {
req.RefreshToken = newRefreshToken
}
expiresAt = newExpiresAt
} else {
expiresAt = time.Now().Unix() + 3600 // 默认 1 小时
}
// 获取用户信息
@@ -1858,13 +1936,14 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s
return
}
// 检查 token 是否过期,需要刷新
if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-60 {
// 先尝试刷新 token(不管是否过期,确保 token 有效)
refreshTokenIfNeeded := func() error {
if account.RefreshToken == "" {
return nil
}
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account)
if err != nil {
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
return
return err
}
account.AccessToken = newAccessToken
if newRefreshToken != "" {
@@ -1873,14 +1952,34 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s
account.ExpiresAt = newExpiresAt
config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt)
h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt)
return nil
}
// 检查 token 是否快过期,先刷新
if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 {
if err := refreshTokenIfNeeded(); err != nil {
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
return
}
}
// 获取账户信息
info, err := RefreshAccountInfo(account)
if err != nil {
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
return
// 如果是 403/401说明 token 无效,尝试刷新后重试
errMsg := err.Error()
if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") {
if refreshErr := refreshTokenIfNeeded(); refreshErr == nil {
// 重试
info, err = RefreshAccountInfo(account)
}
}
if err != nil {
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
return
}
}
// 保存到配置
@@ -2015,3 +2114,161 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request
json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
// apiGetVersion 获取版本信息
func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"version": config.Version,
})
}
// apiExportAccounts 导出账号凭证
func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) {
var req struct {
IDs []string `json:"ids"` // 为空则导出全部
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// 如果 body 为空或解析失败,导出全部
req.IDs = nil
}
accounts := config.GetAccounts()
// 如果指定了 ID只导出指定的
if len(req.IDs) > 0 {
idSet := make(map[string]bool)
for _, id := range req.IDs {
idSet[id] = true
}
var filtered []config.Account
for _, a := range accounts {
if idSet[a.ID] {
filtered = append(filtered, a)
}
}
accounts = filtered
}
// 构建兼容 Kiro Account Manager 的导出格式
type ExportCredentials struct {
AccessToken string `json:"accessToken"`
CsrfToken string `json:"csrfToken"`
RefreshToken string `json:"refreshToken,omitempty"`
ClientID string `json:"clientId,omitempty"`
ClientSecret string `json:"clientSecret,omitempty"`
Region string `json:"region,omitempty"`
ExpiresAt int64 `json:"expiresAt"`
AuthMethod string `json:"authMethod,omitempty"`
Provider string `json:"provider,omitempty"`
}
type ExportSubscription struct {
Type string `json:"type"`
Title string `json:"title,omitempty"`
}
type ExportUsage struct {
Current float64 `json:"current"`
Limit float64 `json:"limit"`
PercentUsed float64 `json:"percentUsed"`
LastUpdated int64 `json:"lastUpdated"`
}
type ExportAccount struct {
ID string `json:"id"`
Email string `json:"email"`
Nickname string `json:"nickname,omitempty"`
Idp string `json:"idp"`
UserId string `json:"userId,omitempty"`
MachineId string `json:"machineId,omitempty"`
Credentials ExportCredentials `json:"credentials"`
Subscription ExportSubscription `json:"subscription"`
Usage ExportUsage `json:"usage"`
Tags []string `json:"tags"`
Status string `json:"status"`
CreatedAt int64 `json:"createdAt"`
LastUsedAt int64 `json:"lastUsedAt"`
}
type ExportData struct {
Version string `json:"version"`
ExportedAt int64 `json:"exportedAt"`
Accounts []ExportAccount `json:"accounts"`
Groups []interface{} `json:"groups"`
Tags []interface{} `json:"tags"`
}
exportAccounts := make([]ExportAccount, 0, len(accounts))
for _, a := range accounts {
// 映射 provider 到 idp
idp := a.Provider
if idp == "" {
if a.AuthMethod == "social" {
idp = "Google"
} else {
idp = "BuilderId"
}
}
// 映射 authMethod
authMethod := a.AuthMethod
if authMethod == "idc" {
authMethod = "IdC"
}
// 映射订阅类型
subType := "Free"
rawType := strings.ToUpper(a.SubscriptionType)
if strings.Contains(rawType, "PRO_PLUS") || strings.Contains(rawType, "PROPLUS") {
subType = "Pro_Plus"
} else if strings.Contains(rawType, "PRO") {
subType = "Pro"
} else if strings.Contains(rawType, "POWER") {
subType = "Pro_Plus"
}
exportAccounts = append(exportAccounts, ExportAccount{
ID: a.ID,
Email: a.Email,
Nickname: a.Nickname,
Idp: idp,
UserId: a.UserId,
MachineId: a.MachineId,
Credentials: ExportCredentials{
AccessToken: a.AccessToken,
CsrfToken: "",
RefreshToken: a.RefreshToken,
ClientID: a.ClientID,
ClientSecret: a.ClientSecret,
Region: a.Region,
ExpiresAt: a.ExpiresAt * 1000, // 转为毫秒时间戳
AuthMethod: authMethod,
Provider: a.Provider,
},
Subscription: ExportSubscription{
Type: subType,
Title: a.SubscriptionTitle,
},
Usage: ExportUsage{
Current: a.UsageCurrent,
Limit: a.UsageLimit,
PercentUsed: a.UsagePercent,
LastUpdated: time.Now().UnixMilli(),
},
Tags: []string{},
Status: "active",
CreatedAt: time.Now().UnixMilli(),
LastUsedAt: time.Now().UnixMilli(),
})
}
data := ExportData{
Version: config.Version,
ExportedAt: time.Now().UnixMilli(),
Accounts: exportAccounts,
Groups: []interface{}{},
Tags: []interface{}{},
}
json.NewEncoder(w).Encode(data)
}