diff --git a/auth/builderid.go b/auth/builderid.go new file mode 100644 index 0000000..a96366c --- /dev/null +++ b/auth/builderid.go @@ -0,0 +1,256 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +// BuilderIdSession Builder ID 登录会话 +type BuilderIdSession struct { + ID string + ClientID string + ClientSecret string + DeviceCode string + UserCode string + VerificationUri string + Interval int + ExpiresAt time.Time + Region string +} + +var ( + builderIdSessions = make(map[string]*BuilderIdSession) + builderIdMu sync.RWMutex +) + +// StartBuilderIdLogin 开始 Builder ID 登录 +func StartBuilderIdLogin(region string) (*BuilderIdSession, error) { + if region == "" { + region = "us-east-1" + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region) + startUrl := "https://view.awsapps.com/start" + scopes := []string{ + "codewhisperer:completions", + "codewhisperer:analysis", + "codewhisperer:conversations", + "codewhisperer:transformations", + "codewhisperer:taskassist", + } + + // Step 1: 注册 OIDC 客户端 + regPayload := map[string]interface{}{ + "clientName": "Kiro API Proxy", + "clientType": "public", + "scopes": scopes, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + "issuerUrl": startUrl, + } + + regBody, _ := json.Marshal(regPayload) + regReq, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(regBody)) + regReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + regResp, err := client.Do(regReq) + if err != nil { + return nil, fmt.Errorf("register client failed: %v", err) + } + defer regResp.Body.Close() + + if regResp.StatusCode != 200 { + respBody, _ := io.ReadAll(regResp.Body) + return nil, fmt.Errorf("register client failed: %d %s", regResp.StatusCode, string(respBody)) + } + + var regResult struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + } + if err := json.NewDecoder(regResp.Body).Decode(®Result); err != nil { + return nil, fmt.Errorf("parse register response failed: %v", err) + } + + // Step 2: 发起设备授权 + authPayload := map[string]string{ + "clientId": regResult.ClientID, + "clientSecret": regResult.ClientSecret, + "startUrl": startUrl, + } + + authBody, _ := json.Marshal(authPayload) + authReq, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(authBody)) + authReq.Header.Set("Content-Type", "application/json") + + authResp, err := client.Do(authReq) + if err != nil { + return nil, fmt.Errorf("device authorization failed: %v", err) + } + defer authResp.Body.Close() + + if authResp.StatusCode != 200 { + respBody, _ := io.ReadAll(authResp.Body) + return nil, fmt.Errorf("device authorization failed: %d %s", authResp.StatusCode, string(respBody)) + } + + var authResult struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationUri string `json:"verificationUri"` + VerificationUriComplete string `json:"verificationUriComplete"` + Interval int `json:"interval"` + ExpiresIn int `json:"expiresIn"` + } + if err := json.NewDecoder(authResp.Body).Decode(&authResult); err != nil { + return nil, fmt.Errorf("parse auth response failed: %v", err) + } + + if authResult.Interval == 0 { + authResult.Interval = 5 + } + if authResult.ExpiresIn == 0 { + authResult.ExpiresIn = 600 + } + + verificationUri := authResult.VerificationUriComplete + if verificationUri == "" { + verificationUri = authResult.VerificationUri + } + + session := &BuilderIdSession{ + ID: GenerateAccountID(), + ClientID: regResult.ClientID, + ClientSecret: regResult.ClientSecret, + DeviceCode: authResult.DeviceCode, + UserCode: authResult.UserCode, + VerificationUri: verificationUri, + Interval: authResult.Interval, + ExpiresAt: time.Now().Add(time.Duration(authResult.ExpiresIn) * time.Second), + Region: region, + } + + builderIdMu.Lock() + builderIdSessions[session.ID] = session + builderIdMu.Unlock() + + // 清理过期会话 + go cleanupExpiredBuilderIdSessions() + + return session, nil +} + +// PollBuilderIdAuth 轮询 Builder ID 授权状态 +func PollBuilderIdAuth(sessionID string) (accessToken, refreshToken, clientID, clientSecret, region string, expiresIn int, status string, err error) { + builderIdMu.RLock() + session, exists := builderIdSessions[sessionID] + builderIdMu.RUnlock() + + if !exists { + return "", "", "", "", "", 0, "", fmt.Errorf("session not found or expired") + } + + if time.Now().After(session.ExpiresAt) { + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("authorization expired") + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", session.Region) + + tokenPayload := map[string]string{ + "clientId": session.ClientID, + "clientSecret": session.ClientSecret, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + "deviceCode": session.DeviceCode, + } + + tokenBody, _ := json.Marshal(tokenPayload) + tokenReq, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(tokenBody)) + tokenReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + tokenResp, err := client.Do(tokenReq) + if err != nil { + return "", "", "", "", "", 0, "", fmt.Errorf("token request failed: %v", err) + } + defer tokenResp.Body.Close() + + if tokenResp.StatusCode == 200 { + var tokenResult struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + if err := json.NewDecoder(tokenResp.Body).Decode(&tokenResult); err != nil { + return "", "", "", "", "", 0, "", fmt.Errorf("parse token response failed: %v", err) + } + + // 清理会话 + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + + return tokenResult.AccessToken, tokenResult.RefreshToken, session.ClientID, session.ClientSecret, session.Region, tokenResult.ExpiresIn, "completed", nil + } + + if tokenResp.StatusCode == 400 { + var errResult struct { + Error string `json:"error"` + } + json.NewDecoder(tokenResp.Body).Decode(&errResult) + + switch errResult.Error { + case "authorization_pending": + return "", "", "", "", "", 0, "pending", nil + case "slow_down": + // 增加轮询间隔 + builderIdMu.Lock() + if s, ok := builderIdSessions[sessionID]; ok { + s.Interval += 5 + } + builderIdMu.Unlock() + return "", "", "", "", "", 0, "slow_down", nil + case "expired_token": + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("device code expired") + case "access_denied": + builderIdMu.Lock() + delete(builderIdSessions, sessionID) + builderIdMu.Unlock() + return "", "", "", "", "", 0, "", fmt.Errorf("user denied authorization") + default: + return "", "", "", "", "", 0, "", fmt.Errorf("authorization error: %s", errResult.Error) + } + } + + return "", "", "", "", "", 0, "", fmt.Errorf("unexpected response: %d", tokenResp.StatusCode) +} + +// GetBuilderIdSession 获取会话信息 +func GetBuilderIdSession(sessionID string) *BuilderIdSession { + builderIdMu.RLock() + defer builderIdMu.RUnlock() + return builderIdSessions[sessionID] +} + +// cleanupExpiredBuilderIdSessions 清理过期会话 +func cleanupExpiredBuilderIdSessions() { + builderIdMu.Lock() + defer builderIdMu.Unlock() + + now := time.Now() + for id, session := range builderIdSessions { + if now.After(session.ExpiresAt) { + delete(builderIdSessions, id) + } + } +} diff --git a/proxy/handler.go b/proxy/handler.go index 0c7f9ae..f2f4432 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -842,6 +842,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { h.apiStartIamSso(w, r) case path == "/auth/iam-sso/complete" && r.Method == "POST": h.apiCompleteIamSso(w, r) + case path == "/auth/builderid/start" && r.Method == "POST": + h.apiStartBuilderIdLogin(w, r) + case path == "/auth/builderid/poll" && r.Method == "POST": + h.apiPollBuilderIdAuth(w, r) case path == "/auth/sso-token" && r.Method == "POST": h.apiImportSsoToken(w, r) case path == "/auth/credentials" && r.Method == "POST": @@ -1072,6 +1076,98 @@ func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) { }) } +func (h *Handler) apiStartBuilderIdLogin(w http.ResponseWriter, r *http.Request) { + var req struct { + Region string `json:"region"` + } + json.NewDecoder(r.Body).Decode(&req) + + session, err := auth.StartBuilderIdLogin(req.Region) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "sessionId": session.ID, + "userCode": session.UserCode, + "verificationUri": session.VerificationUri, + "interval": session.Interval, + }) +} + +func (h *Handler) apiPollBuilderIdAuth(w http.ResponseWriter, r *http.Request) { + var req struct { + SessionID string `json:"sessionId"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + accessToken, refreshToken, clientID, clientSecret, region, expiresIn, status, err := auth.PollBuilderIdAuth(req.SessionID) + if err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "error": err.Error(), + }) + return + } + + if status == "pending" || status == "slow_down" { + // 获取当前间隔 + interval := 5 + if session := auth.GetBuilderIdSession(req.SessionID); session != nil { + interval = session.Interval + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "completed": false, + "status": status, + "interval": interval, + }) + return + } + + // 授权完成,获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Provider: "BuilderId", + Region: region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "completed": true, + "account": map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }, + }) +} + func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) { var req struct { BearerToken string `json:"bearerToken"` @@ -1089,44 +1185,67 @@ func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) { return } - accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(req.BearerToken, req.Region) - if err != nil { - w.WriteHeader(500) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) - return - } + // 支持批量导入,按行分割 + tokens := strings.Split(strings.TrimSpace(req.BearerToken), "\n") + var imported []map[string]interface{} + var errors []string - // 获取用户信息 - email, _, _ := auth.GetUserInfo(accessToken) + for _, token := range tokens { + token = strings.TrimSpace(token) + if token == "" { + continue + } - // 创建账号 - account := config.Account{ - ID: auth.GenerateAccountID(), - Email: email, - AccessToken: accessToken, - RefreshToken: refreshToken, - ClientID: clientID, - ClientSecret: clientSecret, - AuthMethod: "idc", - Region: req.Region, - ExpiresAt: time.Now().Unix() + int64(expiresIn), - Enabled: true, - MachineId: config.GenerateMachineId(), - } + accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(token, req.Region) + if err != nil { + errors = append(errors, err.Error()) + continue + } - if err := config.AddAccount(account); err != nil { - w.WriteHeader(500) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) - return + // 获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Region: req.Region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + errors = append(errors, err.Error()) + continue + } + + imported = append(imported, map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }) } h.pool.Reload() + + if len(imported) == 0 && len(errors) > 0 { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "error": strings.Join(errors, "; "), + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ - "success": true, - "account": map[string]interface{}{ - "id": account.ID, - "email": account.Email, - }, + "success": true, + "accounts": imported, + "errors": errors, }) } diff --git a/web/index.html b/web/index.html index 797802d..c058575 100644 --- a/web/index.html +++ b/web/index.html @@ -169,9 +169,7 @@
通过 AWS Builder ID 授权登录添加个人账号
+等待授权中...
+ +通过 Kiro IDE 本地缓存文件添加账号
+文件位置
+Windows: %USERPROFILE%\\.aws\\sso\\cache\\
macOS/Linux: ~/.aws/sso/cache/
通过 Kiro Account Manager 导出的凭证添加账号
- - `; +如何获取 Token?
+view.awsapps.com/startx-amz-sso_authn 的值通过 IAM Identity Center (企业 SSO) 授权登录添加账号
请在浏览器中完成登录,然后粘贴回调 URL