feat: add versioning, account export, and dynamic models list
This commit is contained in:
327
proxy/handler.go
327
proxy/handler.go
@@ -29,6 +29,10 @@ type Handler struct {
|
||||
startTime int64
|
||||
stopRefresh chan struct{}
|
||||
stopStatsSaver chan struct{}
|
||||
// 模型缓存
|
||||
cachedModels []ModelInfo
|
||||
modelsCacheMu sync.RWMutex
|
||||
modelsCacheTime int64
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
@@ -58,11 +62,13 @@ func (h *Handler) backgroundRefresh() {
|
||||
|
||||
// 启动时延迟 10 秒后执行一次
|
||||
time.Sleep(10 * time.Second)
|
||||
h.refreshModelsCache()
|
||||
h.refreshAllAccounts()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.refreshModelsCache()
|
||||
h.refreshAllAccounts()
|
||||
case <-h.stopRefresh:
|
||||
return
|
||||
@@ -211,19 +217,44 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleModels 模型列表
|
||||
func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
models := []map[string]interface{}{
|
||||
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4.5-thinking", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4-thinking", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5-thinking", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5-thinking", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "auto", "object": "model", "owned_by": "kiro-api"},
|
||||
{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
|
||||
{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
|
||||
// 尝试用缓存的真实模型列表
|
||||
h.modelsCacheMu.RLock()
|
||||
cached := h.cachedModels
|
||||
h.modelsCacheMu.RUnlock()
|
||||
|
||||
thinkingSuffix := config.GetThinkingConfig().Suffix
|
||||
|
||||
var models []map[string]interface{}
|
||||
if len(cached) > 0 {
|
||||
for _, m := range cached {
|
||||
models = append(models, map[string]interface{}{
|
||||
"id": m.ModelId, "object": "model", "owned_by": "anthropic",
|
||||
})
|
||||
// 自动生成 thinking 变体
|
||||
models = append(models, map[string]interface{}{
|
||||
"id": m.ModelId + thinkingSuffix, "object": "model", "owned_by": "anthropic",
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// fallback 静态列表
|
||||
models = []map[string]interface{}{
|
||||
{"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-sonnet-4" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4.5" + thinkingSuffix, "object": "model", "owned_by": "anthropic"},
|
||||
}
|
||||
}
|
||||
// 添加别名模型
|
||||
models = append(models,
|
||||
map[string]interface{}{"id": "auto", "object": "model", "owned_by": "kiro-proxy"},
|
||||
map[string]interface{}{"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"},
|
||||
map[string]interface{}{"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"},
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"object": "list",
|
||||
@@ -231,6 +262,33 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// refreshModelsCache 从 Kiro API 拉取模型列表并缓存
|
||||
func (h *Handler) refreshModelsCache() {
|
||||
account := h.pool.GetNext()
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 确保 token 有效
|
||||
if err := h.ensureValidToken(account); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
models, err := ListAvailableModels(account)
|
||||
if err != nil {
|
||||
fmt.Printf("[ModelsCache] Failed to refresh: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(models) > 0 {
|
||||
h.modelsCacheMu.Lock()
|
||||
h.cachedModels = models
|
||||
h.modelsCacheTime = time.Now().Unix()
|
||||
h.modelsCacheMu.Unlock()
|
||||
fmt.Printf("[ModelsCache] Cached %d models\n", len(models))
|
||||
}
|
||||
}
|
||||
|
||||
// handleCountTokens Token 计数(Claude Code 会调用)
|
||||
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
@@ -1282,6 +1340,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
|
||||
h.apiGetEndpointConfig(w, r)
|
||||
case path == "/endpoint" && r.Method == "POST":
|
||||
h.apiUpdateEndpointConfig(w, r)
|
||||
case path == "/version" && r.Method == "GET":
|
||||
h.apiGetVersion(w, r)
|
||||
case path == "/export" && r.Method == "POST":
|
||||
h.apiExportAccounts(w, r)
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
|
||||
@@ -1707,31 +1769,47 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) {
|
||||
req.AuthMethod = "social"
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 accessToken,尝试刷新获取
|
||||
accessToken := req.AccessToken
|
||||
var expiresAt int64
|
||||
if accessToken == "" {
|
||||
tempAccount := &config.Account{
|
||||
RefreshToken: req.RefreshToken,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthMethod: req.AuthMethod,
|
||||
Region: req.Region,
|
||||
// 标准化 authMethod
|
||||
switch strings.ToLower(req.AuthMethod) {
|
||||
case "idc", "builderid", "enterprise":
|
||||
req.AuthMethod = "idc"
|
||||
case "social", "google", "github":
|
||||
req.AuthMethod = "social"
|
||||
default:
|
||||
if req.ClientID != "" && req.ClientSecret != "" {
|
||||
req.AuthMethod = "idc"
|
||||
} else {
|
||||
req.AuthMethod = "social"
|
||||
}
|
||||
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount)
|
||||
if err != nil {
|
||||
}
|
||||
|
||||
// 始终尝试用 refreshToken 刷新获取新的 accessToken
|
||||
var accessToken string
|
||||
var expiresAt int64
|
||||
tempAccount := &config.Account{
|
||||
RefreshToken: req.RefreshToken,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthMethod: req.AuthMethod,
|
||||
Region: req.Region,
|
||||
}
|
||||
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount)
|
||||
if err != nil {
|
||||
// 刷新失败,如果有传入的 accessToken 则尝试使用
|
||||
if req.AccessToken != "" {
|
||||
accessToken = req.AccessToken
|
||||
expiresAt = time.Now().Unix() + 300 // 可能已过期,设短一点
|
||||
} else {
|
||||
w.WriteHeader(400)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
accessToken = newAccessToken
|
||||
if newRefreshToken != "" {
|
||||
req.RefreshToken = newRefreshToken
|
||||
}
|
||||
expiresAt = newExpiresAt
|
||||
} else {
|
||||
expiresAt = time.Now().Unix() + 3600 // 默认 1 小时
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
@@ -1858,13 +1936,14 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 token 是否过期,需要刷新
|
||||
if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-60 {
|
||||
// 先尝试刷新 token(不管是否过期,确保 token 有效)
|
||||
refreshTokenIfNeeded := func() error {
|
||||
if account.RefreshToken == "" {
|
||||
return nil
|
||||
}
|
||||
newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
|
||||
return
|
||||
return err
|
||||
}
|
||||
account.AccessToken = newAccessToken
|
||||
if newRefreshToken != "" {
|
||||
@@ -1873,14 +1952,34 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s
|
||||
account.ExpiresAt = newExpiresAt
|
||||
config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt)
|
||||
h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查 token 是否快过期,先刷新
|
||||
if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 {
|
||||
if err := refreshTokenIfNeeded(); err != nil {
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取账户信息
|
||||
info, err := RefreshAccountInfo(account)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
// 如果是 403/401,说明 token 无效,尝试刷新后重试
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") {
|
||||
if refreshErr := refreshTokenIfNeeded(); refreshErr == nil {
|
||||
// 重试
|
||||
info, err = RefreshAccountInfo(account)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 保存到配置
|
||||
@@ -2015,3 +2114,161 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]bool{"success": true})
|
||||
}
|
||||
|
||||
// apiGetVersion 获取版本信息
|
||||
func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"version": config.Version,
|
||||
})
|
||||
}
|
||||
|
||||
// apiExportAccounts 导出账号凭证
|
||||
func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
IDs []string `json:"ids"` // 为空则导出全部
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// 如果 body 为空或解析失败,导出全部
|
||||
req.IDs = nil
|
||||
}
|
||||
|
||||
accounts := config.GetAccounts()
|
||||
|
||||
// 如果指定了 ID,只导出指定的
|
||||
if len(req.IDs) > 0 {
|
||||
idSet := make(map[string]bool)
|
||||
for _, id := range req.IDs {
|
||||
idSet[id] = true
|
||||
}
|
||||
var filtered []config.Account
|
||||
for _, a := range accounts {
|
||||
if idSet[a.ID] {
|
||||
filtered = append(filtered, a)
|
||||
}
|
||||
}
|
||||
accounts = filtered
|
||||
}
|
||||
|
||||
// 构建兼容 Kiro Account Manager 的导出格式
|
||||
type ExportCredentials struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
CsrfToken string `json:"csrfToken"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
AuthMethod string `json:"authMethod,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
type ExportSubscription struct {
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type ExportUsage struct {
|
||||
Current float64 `json:"current"`
|
||||
Limit float64 `json:"limit"`
|
||||
PercentUsed float64 `json:"percentUsed"`
|
||||
LastUpdated int64 `json:"lastUpdated"`
|
||||
}
|
||||
|
||||
type ExportAccount struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
Idp string `json:"idp"`
|
||||
UserId string `json:"userId,omitempty"`
|
||||
MachineId string `json:"machineId,omitempty"`
|
||||
Credentials ExportCredentials `json:"credentials"`
|
||||
Subscription ExportSubscription `json:"subscription"`
|
||||
Usage ExportUsage `json:"usage"`
|
||||
Tags []string `json:"tags"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt int64 `json:"createdAt"`
|
||||
LastUsedAt int64 `json:"lastUsedAt"`
|
||||
}
|
||||
|
||||
type ExportData struct {
|
||||
Version string `json:"version"`
|
||||
ExportedAt int64 `json:"exportedAt"`
|
||||
Accounts []ExportAccount `json:"accounts"`
|
||||
Groups []interface{} `json:"groups"`
|
||||
Tags []interface{} `json:"tags"`
|
||||
}
|
||||
|
||||
exportAccounts := make([]ExportAccount, 0, len(accounts))
|
||||
for _, a := range accounts {
|
||||
// 映射 provider 到 idp
|
||||
idp := a.Provider
|
||||
if idp == "" {
|
||||
if a.AuthMethod == "social" {
|
||||
idp = "Google"
|
||||
} else {
|
||||
idp = "BuilderId"
|
||||
}
|
||||
}
|
||||
|
||||
// 映射 authMethod
|
||||
authMethod := a.AuthMethod
|
||||
if authMethod == "idc" {
|
||||
authMethod = "IdC"
|
||||
}
|
||||
|
||||
// 映射订阅类型
|
||||
subType := "Free"
|
||||
rawType := strings.ToUpper(a.SubscriptionType)
|
||||
if strings.Contains(rawType, "PRO_PLUS") || strings.Contains(rawType, "PROPLUS") {
|
||||
subType = "Pro_Plus"
|
||||
} else if strings.Contains(rawType, "PRO") {
|
||||
subType = "Pro"
|
||||
} else if strings.Contains(rawType, "POWER") {
|
||||
subType = "Pro_Plus"
|
||||
}
|
||||
|
||||
exportAccounts = append(exportAccounts, ExportAccount{
|
||||
ID: a.ID,
|
||||
Email: a.Email,
|
||||
Nickname: a.Nickname,
|
||||
Idp: idp,
|
||||
UserId: a.UserId,
|
||||
MachineId: a.MachineId,
|
||||
Credentials: ExportCredentials{
|
||||
AccessToken: a.AccessToken,
|
||||
CsrfToken: "",
|
||||
RefreshToken: a.RefreshToken,
|
||||
ClientID: a.ClientID,
|
||||
ClientSecret: a.ClientSecret,
|
||||
Region: a.Region,
|
||||
ExpiresAt: a.ExpiresAt * 1000, // 转为毫秒时间戳
|
||||
AuthMethod: authMethod,
|
||||
Provider: a.Provider,
|
||||
},
|
||||
Subscription: ExportSubscription{
|
||||
Type: subType,
|
||||
Title: a.SubscriptionTitle,
|
||||
},
|
||||
Usage: ExportUsage{
|
||||
Current: a.UsageCurrent,
|
||||
Limit: a.UsageLimit,
|
||||
PercentUsed: a.UsagePercent,
|
||||
LastUpdated: time.Now().UnixMilli(),
|
||||
},
|
||||
Tags: []string{},
|
||||
Status: "active",
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
LastUsedAt: time.Now().UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
data := ExportData{
|
||||
Version: config.Version,
|
||||
ExportedAt: time.Now().UnixMilli(),
|
||||
Accounts: exportAccounts,
|
||||
Groups: []interface{}{},
|
||||
Tags: []interface{}{},
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user