- Multi-account pool with round-robin load balancing - Auto token refresh for IAM IdC and Social auth - Streaming support (SSE) - Web admin panel with account management - Docker support with GitHub Actions CI/CD - Machine ID management per account - Usage tracking (requests, tokens, credits)
339 lines
9.9 KiB
Go
339 lines
9.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// ImportFromSsoToken 从 SSO Token (x-amz-sso_authn) 导入账号
|
|
func ImportFromSsoToken(bearerToken, region string) (accessToken, refreshToken, clientID, clientSecret string, expiresIn int, err error) {
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
|
portalBase := "https://portal.sso.us-east-1.amazonaws.com"
|
|
startUrl := "https://view.awsapps.com/start"
|
|
|
|
// 1. 注册 OIDC 客户端
|
|
clientID, clientSecret, err = registerDeviceClient(oidcBase, startUrl)
|
|
if err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("注册客户端失败: %w", err)
|
|
}
|
|
|
|
// 2. 发起设备授权
|
|
deviceCode, userCode, interval, err := startDeviceAuth(oidcBase, clientID, clientSecret, startUrl)
|
|
if err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("设备授权失败: %w", err)
|
|
}
|
|
|
|
// 3. 验证 Bearer Token
|
|
if err := verifyBearerToken(portalBase, bearerToken); err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("Token 验证失败: %w", err)
|
|
}
|
|
|
|
// 4. 获取设备会话令牌
|
|
deviceSessionToken, err := getDeviceSessionToken(portalBase, bearerToken)
|
|
if err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("获取设备会话失败: %w", err)
|
|
}
|
|
|
|
// 5. 接受用户代码
|
|
deviceContext, err := acceptUserCode(oidcBase, userCode, deviceSessionToken)
|
|
if err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("接受用户代码失败: %w", err)
|
|
}
|
|
|
|
// 6. 批准授权
|
|
if deviceContext != nil {
|
|
if err := approveAuth(oidcBase, deviceContext, deviceSessionToken); err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("批准授权失败: %w", err)
|
|
}
|
|
}
|
|
|
|
// 7. 轮询获取 Token
|
|
accessToken, refreshToken, expiresIn, err = pollForToken(oidcBase, clientID, clientSecret, deviceCode, interval)
|
|
if err != nil {
|
|
return "", "", "", "", 0, fmt.Errorf("获取 Token 失败: %w", err)
|
|
}
|
|
|
|
return accessToken, refreshToken, clientID, clientSecret, expiresIn, nil
|
|
}
|
|
|
|
func registerDeviceClient(oidcBase, startUrl string) (clientID, clientSecret string, err error) {
|
|
payload := map[string]interface{}{
|
|
"clientName": "Kiro API Proxy",
|
|
"clientType": "public",
|
|
"scopes": scopes,
|
|
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
|
"issuerUrl": startUrl,
|
|
}
|
|
|
|
body, _ := json.Marshal(payload)
|
|
req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return "", "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
ClientID string `json:"clientId"`
|
|
ClientSecret string `json:"clientSecret"`
|
|
}
|
|
json.NewDecoder(resp.Body).Decode(&result)
|
|
return result.ClientID, result.ClientSecret, nil
|
|
}
|
|
|
|
// startDeviceAuth starts OIDC device authorization for a registered client by
// POSTing clientId, clientSecret and startUrl to oidcBase/device_authorization.
//
// It returns the device code (later exchanged at the token endpoint), the user
// code (later accepted on the user's behalf) and the polling interval in
// seconds. A missing, zero or negative interval is normalized to 1.
//
// A request that cannot be built (e.g. malformed oidcBase), a non-200 status
// (reported with a bounded excerpt of the body), an undecodable body or a
// response without deviceCode/userCode are returned as errors.
func startDeviceAuth(oidcBase, clientID, clientSecret, startUrl string) (deviceCode, userCode string, interval int, err error) {
	body, err := json.Marshal(map[string]string{
		"clientId":     clientID,
		"clientSecret": clientSecret,
		"startUrl":     startUrl,
	})
	if err != nil {
		return "", "", 0, fmt.Errorf("encode device authorization request: %w", err)
	}
	// NewRequest fails on a malformed URL; its nil request must not be used.
	req, err := http.NewRequest(http.MethodPost, oidcBase+"/device_authorization", bytes.NewReader(body))
	if err != nil {
		return "", "", 0, fmt.Errorf("build device authorization request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", "", 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return "", "", 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		DeviceCode string `json:"deviceCode"`
		UserCode   string `json:"userCode"`
		Interval   int    `json:"interval"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", "", 0, fmt.Errorf("decode device authorization response: %w", err)
	}
	if result.DeviceCode == "" || result.UserCode == "" {
		return "", "", 0, fmt.Errorf("device authorization response missing deviceCode or userCode")
	}
	// The interval is used to build a polling ticker, and time.NewTicker
	// panics on a non-positive duration, so negatives are clamped as well.
	if result.Interval <= 0 {
		result.Interval = 1
	}
	return result.DeviceCode, result.UserCode, result.Interval, nil
}
|
|
|
|
// verifyBearerToken checks bearerToken against the SSO portal endpoint
// GET portalBase/token/whoAmI.
//
// It returns nil only for an HTTP 200 response. A request that cannot be built
// (e.g. malformed portalBase), a transport failure or any other status is
// returned as an error; non-200 errors include a bounded excerpt of the body.
func verifyBearerToken(portalBase, bearerToken string) error {
	// NewRequest fails on a malformed URL; its nil request must not be used.
	req, err := http.NewRequest(http.MethodGet, portalBase+"/token/whoAmI", nil)
	if err != nil {
		return fmt.Errorf("build whoAmI request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+bearerToken)
	req.Header.Set("Accept", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}
	return nil
}
|
|
|
|
// getDeviceSessionToken creates a device session on the SSO portal by POSTing
// an empty JSON object to portalBase/session/device, authenticated with
// bearerToken, and returns the "token" field of the response. That token is
// used as the userSessionId when accepting and approving the user code.
//
// A request that cannot be built (e.g. malformed portalBase), a non-200 status
// (reported with a bounded excerpt of the body), an undecodable body or an
// empty token are returned as errors.
func getDeviceSessionToken(portalBase, bearerToken string) (string, error) {
	// NewRequest fails on a malformed URL; its nil request must not be used.
	req, err := http.NewRequest(http.MethodPost, portalBase+"/session/device", bytes.NewReader([]byte("{}")))
	if err != nil {
		return "", fmt.Errorf("build device session request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+bearerToken)
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Token string `json:"token"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("decode device session response: %w", err)
	}
	if result.Token == "" {
		return "", fmt.Errorf("device session response missing token")
	}
	return result.Token, nil
}
|
|
|
|
// deviceContextInfo is the "deviceContext" object returned by the
// accept_user_code endpoint. approveAuth sends these same fields back to the
// associate_token endpoint.
type deviceContextInfo struct {
	DeviceContextID string `json:"deviceContextId"`
	ClientID        string `json:"clientId"`
	ClientType      string `json:"clientType"`
}

// acceptUserCode submits userCode for the user session identified by
// deviceSessionToken via POST oidcBase/device_authorization/accept_user_code.
//
// It returns the response's deviceContext. The result is nil when the response
// carries none (including an empty 200 body); callers use that to decide
// whether an explicit approval is needed. A request that cannot be built
// (e.g. malformed oidcBase), a non-200 status (reported with a bounded excerpt
// of the body) or a malformed body are returned as errors.
func acceptUserCode(oidcBase, userCode, deviceSessionToken string) (*deviceContextInfo, error) {
	body, err := json.Marshal(map[string]string{
		"userCode":      userCode,
		"userSessionId": deviceSessionToken,
	})
	if err != nil {
		return nil, fmt.Errorf("encode accept_user_code request: %w", err)
	}
	// NewRequest fails on a malformed URL; its nil request must not be used.
	req, err := http.NewRequest(http.MethodPost, oidcBase+"/device_authorization/accept_user_code", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build accept_user_code request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// NOTE(review): presumably the endpoint expects the access-portal origin as
	// Referer — unverified.
	req.Header.Set("Referer", "https://view.awsapps.com/")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		DeviceContext *deviceContextInfo `json:"deviceContext"`
	}
	// Decode reports an empty body as io.EOF; treat that like a response
	// without a device context, but surface any other malformed body.
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil && err != io.EOF {
		return nil, fmt.Errorf("decode accept_user_code response: %w", err)
	}
	return result.DeviceContext, nil
}

// approveAuth approves the device authorization described by deviceContext for
// the user session deviceSessionToken via
// POST oidcBase/device_authorization/associate_token.
//
// deviceContext must be non-nil. A nil context, a request that cannot be built
// (e.g. malformed oidcBase), a transport failure or a non-200 status (reported
// with a bounded excerpt of the body) are returned as errors.
func approveAuth(oidcBase string, deviceContext *deviceContextInfo, deviceSessionToken string) error {
	if deviceContext == nil {
		return fmt.Errorf("approve authorization: nil device context")
	}

	payload := map[string]interface{}{
		"deviceContext": map[string]string{
			"deviceContextId": deviceContext.DeviceContextID,
			"clientId":        deviceContext.ClientID,
			"clientType":      deviceContext.ClientType,
		},
		"userSessionId": deviceSessionToken,
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("encode associate_token request: %w", err)
	}
	// NewRequest fails on a malformed URL; its nil request must not be used.
	req, err := http.NewRequest(http.MethodPost, oidcBase+"/device_authorization/associate_token", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("build associate_token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// NOTE(review): presumably the endpoint expects the access-portal origin as
	// Referer — unverified.
	req.Header.Set("Referer", "https://view.awsapps.com/")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}
	return nil
}
|
|
|
|
// pollForToken exchanges deviceCode for tokens by polling POST oidcBase/token
// with the device_code grant.
//
// A request is sent every interval seconds (the first after one interval) and
// the outcome depends on the response:
//   - HTTP 200: accessToken, refreshToken and expiresIn are returned; an
//     undecodable body or an empty accessToken is an error;
//   - HTTP 400 "authorization_pending": keep polling;
//   - HTTP 400 "slow_down": the interval grows by 5 seconds (RFC 8628 §3.5);
//   - any other HTTP 400: returned as "授权错误: <code>" ("authorization error");
//   - two minutes elapse: "授权超时" ("authorization timed out"), wrapping the
//     last transport error or unexpected status, if any.
//
// Transport errors and statuses other than 200/400 are treated as transient.
// A non-positive interval is treated as 1 second, since time.NewTicker panics
// on a non-positive duration. A malformed oidcBase fails immediately instead
// of waiting out the timeout.
func pollForToken(oidcBase, clientID, clientSecret, deviceCode string, interval int) (accessToken, refreshToken string, expiresIn int, err error) {
	// The request body never changes between polls, so encode it once.
	body, err := json.Marshal(map[string]string{
		"clientId":     clientID,
		"clientSecret": clientSecret,
		"grantType":    "urn:ietf:params:oauth:grant-type:device_code",
		"deviceCode":   deviceCode,
	})
	if err != nil {
		return "", "", 0, fmt.Errorf("encode token request: %w", err)
	}

	tokenURL := oidcBase + "/token"
	// A malformed URL can never succeed; reject it before starting the clock.
	if _, err := http.NewRequest(http.MethodPost, tokenURL, nil); err != nil {
		return "", "", 0, fmt.Errorf("build token request: %w", err)
	}
	if interval <= 0 {
		interval = 1
	}

	client := &http.Client{Timeout: 30 * time.Second}
	timeout := time.After(2 * time.Minute)
	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()

	// lastErr remembers the most recent transient failure for the timeout error.
	var lastErr error
	for {
		select {
		case <-timeout:
			if lastErr != nil {
				return "", "", 0, fmt.Errorf("授权超时: %w", lastErr)
			}
			return "", "", 0, fmt.Errorf("授权超时")
		case <-ticker.C:
		}

		req, err := http.NewRequest(http.MethodPost, tokenURL, bytes.NewReader(body))
		if err != nil {
			return "", "", 0, fmt.Errorf("build token request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := client.Do(req)
		if err != nil {
			lastErr = err // transient transport failure: retry on the next tick
			continue
		}

		switch resp.StatusCode {
		case http.StatusOK:
			var result struct {
				AccessToken  string `json:"accessToken"`
				RefreshToken string `json:"refreshToken"`
				ExpiresIn    int    `json:"expiresIn"`
			}
			decErr := json.NewDecoder(resp.Body).Decode(&result)
			resp.Body.Close()
			if decErr != nil {
				return "", "", 0, fmt.Errorf("decode token response: %w", decErr)
			}
			if result.AccessToken == "" {
				return "", "", 0, fmt.Errorf("token response missing accessToken")
			}
			return result.AccessToken, result.RefreshToken, result.ExpiresIn, nil

		case http.StatusBadRequest:
			var errResult struct {
				Error string `json:"error"`
			}
			decErr := json.NewDecoder(resp.Body).Decode(&errResult)
			resp.Body.Close()
			switch errResult.Error {
			case "authorization_pending":
				continue
			case "slow_down":
				interval += 5
				ticker.Reset(time.Duration(interval) * time.Second)
				continue
			case "":
				if decErr != nil {
					return "", "", 0, fmt.Errorf("授权错误: %w", decErr)
				}
			}
			return "", "", 0, fmt.Errorf("授权错误: %s", errResult.Error)

		default:
			// Best-effort, bounded read; draining also lets the connection be reused.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
			resp.Body.Close()
			lastErr = fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
		}
	}
}
|
|
|
|
// GetUserInfo returns the email and user ID of the account that owns
// accessToken.
//
// It queries the Kiro getUsageLimits endpoint in us-east-1 with
// isEmailRequired=true; besides usage data, that response carries a "userInfo"
// object. An empty accessToken, a transport failure, a non-200 status
// (reported with a bounded excerpt of the body) or an undecodable body are
// returned as errors. Missing fields in a well-formed response come back as
// empty strings.
func GetUserInfo(accessToken string) (email, userID string, err error) {
	if accessToken == "" {
		return "", "", fmt.Errorf("empty access token")
	}

	// The usage-limits call is used because its response includes user info.
	const usageURL = "https://q.us-east-1.amazonaws.com/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true"

	req, err := http.NewRequest(http.MethodGet, usageURL, nil)
	if err != nil {
		return "", "", fmt.Errorf("build usage request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "aws-sdk-js/1.0.18 KiroAPIProxy")
	req.Header.Set("x-amz-user-agent", "aws-sdk-js/1.0.18 KiroAPIProxy")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort, bounded read: the body only enriches the error message.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return "", "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		UserInfo struct {
			Email  string `json:"email"`
			UserID string `json:"userId"`
		} `json:"userInfo"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", "", fmt.Errorf("decode usage response: %w", err)
	}
	return result.UserInfo.Email, result.UserInfo.UserID, nil
}
|
|
|
|
// GenerateAccountID 生成账号 ID
|
|
func GenerateAccountID() string {
|
|
return uuid.New().String()
|
|
}
|