diff --git a/auth/builderid.go b/auth/builderid.go new file mode 100644 index 0000000..a96366c --- /dev/null +++ b/auth/builderid.go @@ -0,0 +1,256 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +// BuilderIdSession Builder ID 登录会话 +type BuilderIdSession struct { + ID string + ClientID string + ClientSecret string + DeviceCode string + UserCode string + VerificationUri string + Interval int + ExpiresAt time.Time + Region string +} + +var ( + builderIdSessions = make(map[string]*BuilderIdSession) + builderIdMu sync.RWMutex +) + +// StartBuilderIdLogin 开始 Builder ID 登录 +func StartBuilderIdLogin(region string) (*BuilderIdSession, error) { + if region == "" { + region = "us-east-1" + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region) + startUrl := "https://view.awsapps.com/start" + scopes := []string{ + "codewhisperer:completions", + "codewhisperer:analysis", + "codewhisperer:conversations", + "codewhisperer:transformations", + "codewhisperer:taskassist", + } + + // Step 1: 注册 OIDC 客户端 + regPayload := map[string]interface{}{ + "clientName": "Kiro API Proxy", + "clientType": "public", + "scopes": scopes, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + "issuerUrl": startUrl, + } + + regBody, _ := json.Marshal(regPayload) + regReq, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(regBody)) + regReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + regResp, err := client.Do(regReq) + if err != nil { + return nil, fmt.Errorf("register client failed: %v", err) + } + defer regResp.Body.Close() + + if regResp.StatusCode != 200 { + respBody, _ := io.ReadAll(regResp.Body) + return nil, fmt.Errorf("register client failed: %d %s", regResp.StatusCode, string(respBody)) + } + + var regResult struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + } + if err := json.NewDecoder(regResp.Body).Decode(®Result); err != nil { + return nil, fmt.Errorf("parse register response failed: %v", err) + } + + // Step 2: 发起设备授权 + authPayload := map[string]string{ + "clientId": regResult.ClientID, + "clientSecret": regResult.ClientSecret, + "startUrl": startUrl, + } + + authBody, _ := json.Marshal(authPayload) + authReq, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(authBody)) + authReq.Header.Set("Content-Type", "application/json") + + authResp, err := client.Do(authReq) + if err != nil { + return nil, fmt.Errorf("device authorization failed: %v", err) + } + defer authResp.Body.Close() + + if authResp.StatusCode != 200 { + respBody, _ := io.ReadAll(authResp.Body) + return nil, fmt.Errorf("device authorization failed: %d %s", authResp.StatusCode, string(respBody)) + } + + var authResult struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationUri string `json:"verificationUri"` + VerificationUriComplete string `json:"verificationUriComplete"` + Interval int `json:"interval"` + ExpiresIn int `json:"expiresIn"` + } + if err := json.NewDecoder(authResp.Body).Decode(&authResult); err != nil { + return nil, fmt.Errorf("parse auth response failed: %v", err) + } + + if authResult.Interval == 0 { + authResult.Interval = 5 + } + if authResult.ExpiresIn == 0 { + authResult.ExpiresIn = 600 + } + + verificationUri := authResult.VerificationUriComplete + if verificationUri == "" { + verificationUri = authResult.VerificationUri + } + + session := &BuilderIdSession{ + ID: GenerateAccountID(), + ClientID: regResult.ClientID, + ClientSecret: regResult.ClientSecret, + DeviceCode: authResult.DeviceCode, + UserCode: authResult.UserCode, + VerificationUri: verificationUri, + Interval: authResult.Interval, + ExpiresAt: time.Now().Add(time.Duration(authResult.ExpiresIn) * time.Second), + Region: region, + } + + builderIdMu.Lock() + builderIdSessions[session.ID] = session + builderIdMu.Unlock() + + // 清理过期会话 + go cleanupExpiredBuilderIdSessions() + + return session, nil +} + +// PollBuilderIdAuth 轮询 Builder ID 授权状态 +func PollBuilderIdAuth(sessionID string) (accessToken, refreshToken, clientID, clientSecret, region string, expiresIn int, status string, err error) { + builderIdMu.RLock() + session, exists := builderIdSessions[sessionID] + builderIdMu.RUnlock() + + if !exists { + return "", "", "", "", "", 0, "", fmt.Errorf("session not found or expired") + } + + if time.Now().After(session.ExpiresAt) { + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("authorization expired") + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", session.Region) + + tokenPayload := map[string]string{ + "clientId": session.ClientID, + "clientSecret": session.ClientSecret, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + "deviceCode": session.DeviceCode, + } + + tokenBody, _ := json.Marshal(tokenPayload) + tokenReq, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(tokenBody)) + tokenReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + tokenResp, err := client.Do(tokenReq) + if err != nil { + return "", "", "", "", "", 0, "", fmt.Errorf("token request failed: %v", err) + } + defer tokenResp.Body.Close() + + if tokenResp.StatusCode == 200 { + var tokenResult struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + if err := json.NewDecoder(tokenResp.Body).Decode(&tokenResult); err != nil { + return "", "", "", "", "", 0, "", fmt.Errorf("parse token response failed: %v", err) + } + + // 清理会话 + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + + return tokenResult.AccessToken, tokenResult.RefreshToken, session.ClientID, session.ClientSecret, session.Region, tokenResult.ExpiresIn, "completed", nil + } + + if tokenResp.StatusCode == 400 { + var errResult struct { + Error string `json:"error"` + } + json.NewDecoder(tokenResp.Body).Decode(&errResult) + + switch errResult.Error { + case "authorization_pending": + return "", "", "", "", "", 0, "pending", nil + case "slow_down": + // 增加轮询间隔 + builderIdMu.Lock() + if s, ok := builderIdSessions[sessionID]; ok { + s.Interval += 5 + } + builderIdMu.Unlock() + return "", "", "", "", "", 0, "slow_down", nil + case "expired_token": + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("device code expired") + case "access_denied": + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("user denied authorization") + default: + return "", "", "", "", "", 0, "", fmt.Errorf("authorization error: %s", errResult.Error) + } + } + + return "", "", "", "", "", 0, "", fmt.Errorf("unexpected response: %d", tokenResp.StatusCode) +} + +// GetBuilderIdSession 获取会话信息 +func GetBuilderIdSession(sessionID string) *BuilderIdSession { + builderIdMu.RLock() + defer builderIdMu.RUnlock() + return builderIdSessions[sessionID] +} + +// cleanupExpiredBuilderIdSessions 清理过期会话 +func cleanupExpiredBuilderIdSessions() { + builderIdMu.Lock() + defer builderIdMu.Unlock() + + now := time.Now() + for id, session := range builderIdSessions { + if now.After(session.ExpiresAt) { + delete(builderIdSessions, id) + } + } +} diff --git a/proxy/handler.go b/proxy/handler.go index 0c7f9ae..f2f4432 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -842,6 +842,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { h.apiStartIamSso(w, r) case path == "/auth/iam-sso/complete" && r.Method == "POST": h.apiCompleteIamSso(w, r) + case path == "/auth/builderid/start" && r.Method == "POST": + h.apiStartBuilderIdLogin(w, r) + case path == "/auth/builderid/poll" && r.Method == "POST": + h.apiPollBuilderIdAuth(w, r) case path == "/auth/sso-token" && r.Method == "POST": h.apiImportSsoToken(w, r) case path == "/auth/credentials" && r.Method == "POST": @@ -1072,6 +1076,98 @@ func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) { }) } +func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) { + var req struct { + Region string `json:"region"` + } + json.NewDecoder(r.Body).Decode(&req) + + session, err := auth.StartBuilderIdLogin(req.Region) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "sessionId": session.ID, + "userCode": session.UserCode, + "verificationUri": session.VerificationUri, + "interval": session.Interval, + }) +} + +func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) { + var req struct { + SessionID string `json:"sessionId"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + accessToken, refreshToken, clientID, clientSecret, region, expiresIn, status, err := auth.PollBuilderIdAuth(req.SessionID) + if err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "error": err.Error(), + }) + return + } + + if status == "pending" || status == "slow_down" { + // 获取当前间隔 + interval := 5 + if session := auth.GetBuilderIdSession(req.SessionID); session != nil { + interval = session.Interval + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "completed": false, + "status": status, + "interval": interval, + }) + return + } + + // 授权完成,获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Provider: "BuilderId", + Region: region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "completed": true, + "account": map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }, + }) +} + func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) { var req struct { BearerToken string `json:"bearerToken"` @@ -1089,44 +1185,67 @@ func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) { return } - accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(req.BearerToken, req.Region) - if err != nil { - w.WriteHeader(500) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) - return - } + // 支持批量导入,按行分割 + tokens := strings.Split(strings.TrimSpace(req.BearerToken), "\n") + var imported []map[string]interface{} + var errors []string - // 获取用户信息 - email, _, _ := auth.GetUserInfo(accessToken) + for _, token := range tokens { + token = strings.TrimSpace(token) + if token == "" { + continue + } - // 创建账号 - account := config.Account{ - ID: auth.GenerateAccountID(), - Email: email, - AccessToken: accessToken, - RefreshToken: refreshToken, - ClientID: clientID, - ClientSecret: clientSecret, - AuthMethod: "idc", - Region: req.Region, - ExpiresAt: time.Now().Unix() + int64(expiresIn), - Enabled: true, - MachineId: config.GenerateMachineId(), - } + accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(token, req.Region) + if err != nil { + errors = append(errors, err.Error()) + continue + } - if err := config.AddAccount(account); err != nil { - w.WriteHeader(500) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) - return + // 获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Region: req.Region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + errors = append(errors, err.Error()) + continue + } + + imported = append(imported, map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }) } h.pool.Reload() + + if len(imported) == 0 && len(errors) > 0 { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "error": strings.Join(errors, "; "), + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ - "success": true, - "account": map[string]interface{}{ - "id": account.ID, - "email": account.Email, - }, + "success": true, + "accounts": imported, + "errors": errors, }) } diff --git a/web/index.html b/web/index.html index 797802d..c058575 100644 --- a/web/index.html +++ b/web/index.html @@ -169,9 +169,7 @@
账号列表
- - - +
@@ -319,13 +317,13 @@
${a.email || a.id.substring(0,12)+'...'}
${getSubBadge(a.subscriptionType)} - ${a.authMethod || '-'} + ${formatAuthMethod(a.provider || a.authMethod)} ${getStatusBadge(a)}
- +
@@ -355,6 +353,13 @@ return 'FREE'; } + function formatAuthMethod(method) { + if (!method) return '-'; + if (method === 'idc') return 'Enterprise'; + if (method === 'social') return 'Social'; + return method; + } + function getStatusBadge(a) { if (!a.hasToken) return '无Token'; if (a.expiresAt && a.expiresAt < Date.now()/1000) return '已过期'; @@ -391,7 +396,7 @@
邮箱
${a.email || '-'}
用户ID
${a.userId || '-'}
-
认证方式
${a.authMethod || '-'}
+
认证方式
${formatAuthMethod(a.provider || a.authMethod)}
Region
${a.region || 'us-east-1'}
@@ -537,33 +542,211 @@ const title = document.getElementById('modalTitle'); const body = document.getElementById('modalBody'); - if (type === 'credentials') { - title.textContent = '导入凭证'; + if (type === 'add') { + title.textContent = '添加账号'; body.innerHTML = ` +
+
+
AWS Builder ID
+
通过 AWS Builder ID 授权登录添加个人账号
+
+
+
IAM Identity Center (企业 SSO) 登录
+
通过 IAM Identity Center (企业 SSO) 授权添加企业账号
+
+
+
SSO Token
+
通过浏览器 x-amz-sso_authn Token 添加账号
+
+
+
Kiro 本地缓存
+
通过 Kiro IDE 本地缓存文件添加账号
+
+
+
凭证 JSON
+
通过 Kiro Account Manager 导出的凭证添加账号
+
+
+ `; + } else if (type === 'builderid') { + title.textContent = 'AWS Builder ID'; + body.innerHTML = ` +

通过 AWS Builder ID 授权登录添加个人账号

+
+
+ + +
+ +
+ `; + } else if (type === 'local') { + title.textContent = 'Kiro 本地缓存'; + body.innerHTML = ` +

通过 Kiro IDE 本地缓存文件添加账号

+
+

文件位置

+

Windows: %USERPROFILE%\\.aws\\sso\\cache\\

+

macOS/Linux: ~/.aws/sso/cache/

+
+
+ + +
+
+ +
+ + +
+
+
+ +
+ + +
+
+ `; + } else if (type === 'credentials') { + title.textContent = '凭证 JSON'; + body.innerHTML = ` +

通过 Kiro Account Manager 导出的凭证添加账号

-
- `; +
+ + +
+ `; } else if (type === 'sso') { title.textContent = 'SSO Token'; body.innerHTML = ` -
+
+

如何获取 Token?

+
    +
  1. 在浏览器中访问并登录 view.awsapps.com/start
  2. +
  3. 按 F12 打开开发者工具 → Application → Cookies
  4. +
  5. 找到并复制 x-amz-sso_authn 的值
  6. +
+
+
- `; + `; } else if (type === 'iam') { - title.textContent = 'IAM SSO 登录'; + title.textContent = 'IAM Identity Center (企业 SSO) 登录'; body.innerHTML = ` +

通过 IAM Identity Center (企业 SSO) 授权登录添加账号

- `; + `; } modal.classList.add('active'); } - function closeModal() { document.getElementById('addModal').classList.remove('active'); iamSession = ''; } + function closeModal() { + document.getElementById('addModal').classList.remove('active'); + iamSession = ''; + if (builderIdPollTimer) { clearTimeout(builderIdPollTimer); builderIdPollTimer = null; } + builderIdSession = ''; + } + + function loadLocalFile(input, targetId) { + const file = input.files[0]; + if (!file) return; + const reader = new FileReader(); + reader.onload = e => { document.getElementById(targetId).value = e.target.result; }; + reader.readAsText(file); + } + + function updateLocalFields() { + const provider = document.getElementById('localProvider').value; + const clientGroup = document.getElementById('localClientGroup'); + if (provider === 'Google' || provider === 'Github') { + clientGroup.style.display = 'none'; + } else { + clientGroup.style.display = 'block'; + } + } + + async function importLocalKiro() { + const provider = document.getElementById('localProvider').value; + const tokenJson = document.getElementById('localTokenJson').value.trim(); + const clientJson = document.getElementById('localClientJson').value.trim(); + const isSocial = provider === 'Google' || provider === 'Github'; + + if (!tokenJson) { alert('请提供 kiro-auth-token.json 内容'); return; } + + let tokenData, clientData; + try { + tokenData = JSON.parse(tokenJson); + } catch { alert('kiro-auth-token.json 格式错误'); return; } + + if (!tokenData.refreshToken) { alert('缺少 refreshToken'); return; } + + if (!isSocial) { + if (!clientJson) { alert('IdC 登录需要提供 {hash}.json 内容'); return; } + try { + clientData = JSON.parse(clientJson); + } catch { alert('{hash}.json 格式错误'); return; } + if (!clientData.clientId || !clientData.clientSecret) { alert('缺少 clientId 或 clientSecret'); return; } + } + + const payload = { + refreshToken: tokenData.refreshToken, + accessToken: tokenData.accessToken || '', + clientId: clientData?.clientId || '', + clientSecret: clientData?.clientSecret || '', + authMethod: isSocial ? 'social' : 'idc', + provider: provider + }; + + const res = await fetch('/admin/api/auth/credentials', { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify(payload) + }); + const d = await res.json(); + if (d.success) { closeModal(); loadAccounts(); loadStats(); alert('导入成功: ' + (d.account?.email || d.account?.id)); } + else alert('失败: ' + d.error); + } async function importCredentials() { try { @@ -585,8 +768,64 @@ body: JSON.stringify({ bearerToken: document.getElementById('ssoToken').value, region: document.getElementById('ssoRegion').value }) }); const d = await res.json(); - if (d.success) { closeModal(); loadAccounts(); loadStats(); alert('导入成功: ' + (d.account?.email || d.account?.id)); } - else alert('失败: ' + d.error); + if (d.success) { + closeModal(); loadAccounts(); loadStats(); + const count = d.accounts?.length || 0; + const errCount = d.errors?.length || 0; + let msg = '成功添加 ' + count + ' 个账号'; + if (errCount > 0) msg += ',' + errCount + ' 个失败'; + alert(msg); + } else alert('失败: ' + d.error); + } + + let builderIdSession = ''; + let builderIdPollTimer = null; + + async function startBuilderIdLogin() { + const region = document.getElementById('builderIdRegion').value || 'us-east-1'; + const res = await fetch('/admin/api/auth/builderid/start', { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify({ region }) + }); + const d = await res.json(); + if (d.sessionId) { + builderIdSession = d.sessionId; + document.getElementById('builderIdUserCode').textContent = d.userCode; + document.getElementById('builderIdVerifyUrl').textContent = d.verificationUri; + document.getElementById('builderIdStep1').classList.add('hidden'); + document.getElementById('builderIdStep2').classList.remove('hidden'); + // 开始轮询 + pollBuilderIdAuth(d.interval || 5); + } else alert('失败: ' + d.error); + } + + function pollBuilderIdAuth(interval) { + builderIdPollTimer = setTimeout(async () => { + const res = await fetch('/admin/api/auth/builderid/poll', { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify({ sessionId: builderIdSession }) + }); + const d = await res.json(); + if (d.completed) { + closeModal(); loadAccounts(); loadStats(); + alert('登录成功: ' + (d.account?.email || d.account?.id)); + } else if (d.success && !d.completed) { + document.getElementById('builderIdStatus').textContent = '等待授权中...'; + pollBuilderIdAuth(d.interval || interval); + } else { + alert('失败: ' + d.error); + cancelBuilderIdLogin(); + } + }, interval * 1000); + } + + function cancelBuilderIdLogin() { + if (builderIdPollTimer) { + clearTimeout(builderIdPollTimer); + builderIdPollTimer = null; + } + builderIdSession = ''; + showModal('add'); } let iamSession = ''; @@ -607,7 +846,7 @@ const d = await res.json(); if (d.authorizeUrl) { iamSession = d.sessionId; - window.open(d.authorizeUrl, '_blank'); + document.getElementById('iamAuthUrl').textContent = d.authorizeUrl; document.getElementById('iamStep2').classList.remove('hidden'); document.getElementById('iamBtn').textContent = '完成登录'; } else alert('失败: ' + d.error);