257 lines
7.3 KiB
Go
257 lines
7.3 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// BuilderIdSession holds the state of one in-progress Builder ID device
// login, from the moment the device authorization is issued until it is
// completed, rejected, or expires.
type BuilderIdSession struct {
	// ID is the key under which the session is stored in the package-level
	// session table; StartBuilderIdLogin fills it from GenerateAccountID.
	ID string

	// ClientID and ClientSecret are the OIDC client credentials returned by
	// the client registration call.
	ClientID, ClientSecret string

	// DeviceCode is sent to the token endpoint on every poll; UserCode is the
	// code returned alongside it by the device authorization call.
	DeviceCode, UserCode string

	// VerificationUri is the URL returned by the device authorization call:
	// the "complete" variant when the service supplies one, the plain one
	// otherwise.
	VerificationUri string

	// Interval is the polling interval reported by the service (5 when the
	// service omits it); it is raised by 5 each time a poll is answered with
	// "slow_down". Presumably seconds, per the device-grant convention.
	Interval int

	// ExpiresAt is the instant after which the session is treated as expired.
	ExpiresAt time.Time

	// Region is the AWS region whose OIDC endpoint the session talks to.
	Region string
}
|
|
|
|
var (
|
|
builderIdSessions = make(map[string]*BuilderIdSession)
|
|
builderIdMu sync.RWMutex
|
|
)
|
|
|
|
// StartBuilderIdLogin 开始 Builder ID 登录
|
|
func StartBuilderIdLogin(region string) (*BuilderIdSession, error) {
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
|
startUrl := "https://view.awsapps.com/start"
|
|
scopes := []string{
|
|
"codewhisperer:completions",
|
|
"codewhisperer:analysis",
|
|
"codewhisperer:conversations",
|
|
"codewhisperer:transformations",
|
|
"codewhisperer:taskassist",
|
|
}
|
|
|
|
// Step 1: 注册 OIDC 客户端
|
|
regPayload := map[string]interface{}{
|
|
"clientName": "Kiro",
|
|
"clientType": "public",
|
|
"scopes": scopes,
|
|
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
|
"issuerUrl": startUrl,
|
|
}
|
|
|
|
regBody, _ := json.Marshal(regPayload)
|
|
regReq, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(regBody))
|
|
regReq.Header.Set("Content-Type", "application/json")
|
|
|
|
client := httpClient
|
|
regResp, err := client.Do(regReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("register client failed: %v", err)
|
|
}
|
|
defer regResp.Body.Close()
|
|
|
|
if regResp.StatusCode != 200 {
|
|
respBody, _ := io.ReadAll(regResp.Body)
|
|
return nil, fmt.Errorf("register client failed: %d %s", regResp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var regResult struct {
|
|
ClientID string `json:"clientId"`
|
|
ClientSecret string `json:"clientSecret"`
|
|
}
|
|
if err := json.NewDecoder(regResp.Body).Decode(®Result); err != nil {
|
|
return nil, fmt.Errorf("parse register response failed: %v", err)
|
|
}
|
|
|
|
// Step 2: 发起设备授权
|
|
authPayload := map[string]string{
|
|
"clientId": regResult.ClientID,
|
|
"clientSecret": regResult.ClientSecret,
|
|
"startUrl": startUrl,
|
|
}
|
|
|
|
authBody, _ := json.Marshal(authPayload)
|
|
authReq, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(authBody))
|
|
authReq.Header.Set("Content-Type", "application/json")
|
|
|
|
authResp, err := client.Do(authReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("device authorization failed: %v", err)
|
|
}
|
|
defer authResp.Body.Close()
|
|
|
|
if authResp.StatusCode != 200 {
|
|
respBody, _ := io.ReadAll(authResp.Body)
|
|
return nil, fmt.Errorf("device authorization failed: %d %s", authResp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var authResult struct {
|
|
DeviceCode string `json:"deviceCode"`
|
|
UserCode string `json:"userCode"`
|
|
VerificationUri string `json:"verificationUri"`
|
|
VerificationUriComplete string `json:"verificationUriComplete"`
|
|
Interval int `json:"interval"`
|
|
ExpiresIn int `json:"expiresIn"`
|
|
}
|
|
if err := json.NewDecoder(authResp.Body).Decode(&authResult); err != nil {
|
|
return nil, fmt.Errorf("parse auth response failed: %v", err)
|
|
}
|
|
|
|
if authResult.Interval == 0 {
|
|
authResult.Interval = 5
|
|
}
|
|
if authResult.ExpiresIn == 0 {
|
|
authResult.ExpiresIn = 600
|
|
}
|
|
|
|
verificationUri := authResult.VerificationUriComplete
|
|
if verificationUri == "" {
|
|
verificationUri = authResult.VerificationUri
|
|
}
|
|
|
|
session := &BuilderIdSession{
|
|
ID: GenerateAccountID(),
|
|
ClientID: regResult.ClientID,
|
|
ClientSecret: regResult.ClientSecret,
|
|
DeviceCode: authResult.DeviceCode,
|
|
UserCode: authResult.UserCode,
|
|
VerificationUri: verificationUri,
|
|
Interval: authResult.Interval,
|
|
ExpiresAt: time.Now().Add(time.Duration(authResult.ExpiresIn) * time.Second),
|
|
Region: region,
|
|
}
|
|
|
|
builderIdMu.Lock()
|
|
builderIdSessions[session.ID] = session
|
|
builderIdMu.Unlock()
|
|
|
|
// 清理过期会话
|
|
go cleanupExpiredBuilderIdSessions()
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// PollBuilderIdAuth 轮询 Builder ID 授权状态
|
|
func PollBuilderIdAuth(sessionID string) (accessToken, refreshToken, clientID, clientSecret, region string, expiresIn int, status string, err error) {
|
|
builderIdMu.RLock()
|
|
session, exists := builderIdSessions[sessionID]
|
|
builderIdMu.RUnlock()
|
|
|
|
if !exists {
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("session not found or expired")
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
builderIdMu.Lock()
|
|
delete(builderIdSessions, sessionID)
|
|
builderIdMu.Unlock()
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("authorization expired")
|
|
}
|
|
|
|
oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", session.Region)
|
|
|
|
tokenPayload := map[string]string{
|
|
"clientId": session.ClientID,
|
|
"clientSecret": session.ClientSecret,
|
|
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
|
"deviceCode": session.DeviceCode,
|
|
}
|
|
|
|
tokenBody, _ := json.Marshal(tokenPayload)
|
|
tokenReq, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(tokenBody))
|
|
tokenReq.Header.Set("Content-Type", "application/json")
|
|
|
|
client := httpClient
|
|
tokenResp, err := client.Do(tokenReq)
|
|
if err != nil {
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("token request failed: %v", err)
|
|
}
|
|
defer tokenResp.Body.Close()
|
|
|
|
if tokenResp.StatusCode == 200 {
|
|
var tokenResult struct {
|
|
AccessToken string `json:"accessToken"`
|
|
RefreshToken string `json:"refreshToken"`
|
|
ExpiresIn int `json:"expiresIn"`
|
|
}
|
|
if err := json.NewDecoder(tokenResp.Body).Decode(&tokenResult); err != nil {
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("parse token response failed: %v", err)
|
|
}
|
|
|
|
// 清理会话
|
|
builderIdMu.Lock()
|
|
delete(builderIdSessions, sessionID)
|
|
builderIdMu.Unlock()
|
|
|
|
return tokenResult.AccessToken, tokenResult.RefreshToken, session.ClientID, session.ClientSecret, session.Region, tokenResult.ExpiresIn, "completed", nil
|
|
}
|
|
|
|
if tokenResp.StatusCode == 400 {
|
|
var errResult struct {
|
|
Error string `json:"error"`
|
|
}
|
|
json.NewDecoder(tokenResp.Body).Decode(&errResult)
|
|
|
|
switch errResult.Error {
|
|
case "authorization_pending":
|
|
return "", "", "", "", "", 0, "pending", nil
|
|
case "slow_down":
|
|
// 增加轮询间隔
|
|
builderIdMu.Lock()
|
|
if s, ok := builderIdSessions[sessionID]; ok {
|
|
s.Interval += 5
|
|
}
|
|
builderIdMu.Unlock()
|
|
return "", "", "", "", "", 0, "slow_down", nil
|
|
case "expired_token":
|
|
builderIdMu.Lock()
|
|
delete(builderIdSessions, sessionID)
|
|
builderIdMu.Unlock()
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("device code expired")
|
|
case "access_denied":
|
|
builderIdMu.Lock()
|
|
delete(builderIdSessions, sessionID)
|
|
builderIdMu.Unlock()
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("user denied authorization")
|
|
default:
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("authorization error: %s", errResult.Error)
|
|
}
|
|
}
|
|
|
|
return "", "", "", "", "", 0, "", fmt.Errorf("unexpected response: %d", tokenResp.StatusCode)
|
|
}
|
|
|
|
// GetBuilderIdSession 获取会话信息
|
|
func GetBuilderIdSession(sessionID string) *BuilderIdSession {
|
|
builderIdMu.RLock()
|
|
defer builderIdMu.RUnlock()
|
|
return builderIdSessions[sessionID]
|
|
}
|
|
|
|
// cleanupExpiredBuilderIdSessions 清理过期会话
|
|
func cleanupExpiredBuilderIdSessions() {
|
|
builderIdMu.Lock()
|
|
defer builderIdMu.Unlock()
|
|
|
|
now := time.Now()
|
|
for id, session := range builderIdSessions {
|
|
if now.After(session.ExpiresAt) {
|
|
delete(builderIdSessions, id)
|
|
}
|
|
}
|
|
}
|