fix: resolve Kiro profile ARN for generation requests (#46)
This commit is contained in:
@@ -50,6 +50,7 @@ type Account struct {
|
||||
StartUrl string `json:"startUrl,omitempty"` // AWS SSO start URL
|
||||
ExpiresAt int64 `json:"expiresAt,omitempty"` // Token expiration timestamp (Unix seconds)
|
||||
MachineId string `json:"machineId,omitempty"` // UUID machine identifier for request tracking
|
||||
ProfileArn string `json:"profileArn,omitempty"` // CodeWhisperer/Kiro profile ARN for generation requests
|
||||
|
||||
// Priority weight for load balancing (higher = more requests)
|
||||
Weight int `json:"weight,omitempty"` // 0 or 1 = normal, 2+ = higher priority
|
||||
@@ -274,6 +275,18 @@ func UpdateAccount(id string, account Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAccountProfileArn(id, profileArn string) error {
|
||||
cfgLock.Lock()
|
||||
defer cfgLock.Unlock()
|
||||
for i, a := range cfg.Accounts {
|
||||
if a.ID == id {
|
||||
cfg.Accounts[i].ProfileArn = profileArn
|
||||
return Save()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteAccount(id string) error {
|
||||
cfgLock.Lock()
|
||||
defer cfgLock.Unlock()
|
||||
|
||||
@@ -196,6 +196,17 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
|
||||
if _, err := json.Marshal(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
if payload != nil && strings.TrimSpace(payload.ProfileArn) == "" {
|
||||
if profileArn, err := ResolveProfileArn(account); err == nil {
|
||||
payload.ProfileArn = profileArn
|
||||
} else {
|
||||
accountEmail := "<nil>"
|
||||
if account != nil {
|
||||
accountEmail = account.Email
|
||||
}
|
||||
fmt.Printf("[ProfileArn] Failed to resolve profile ARN for %s: %v\n", accountEmail, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 根据配置排序端点
|
||||
endpoints := getSortedEndpoints(config.GetPreferredEndpoint())
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"kiro-go/config"
|
||||
"net/http"
|
||||
neturl "net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -17,6 +18,7 @@ const (
|
||||
// GetUsageLimits 获取账户使用量和订阅信息
|
||||
func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) {
|
||||
url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase)
|
||||
url = withProfileArnQuery(url, account)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
@@ -77,6 +79,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) {
|
||||
// ListAvailableModels 获取可用模型列表
|
||||
func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
|
||||
url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase)
|
||||
url = withProfileArnQuery(url, account)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
@@ -105,6 +108,66 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
|
||||
return result.Models, nil
|
||||
}
|
||||
|
||||
// ResolveProfileArn returns the account profile ARN, fetching and caching it
|
||||
// when it is missing. Some Kiro generation requests require this profile for
|
||||
// model authorization even when model listing works without it.
|
||||
func ResolveProfileArn(account *config.Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", fmt.Errorf("account is nil")
|
||||
}
|
||||
if profileArn := strings.TrimSpace(account.ProfileArn); profileArn != "" {
|
||||
return profileArn, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/ListAvailableProfiles", kiroRestAPIBase), strings.NewReader(`{"maxResults":10}`))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
setKiroHeaders(req, account)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := kiroRestHttpStore.Load().Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Profiles []struct {
|
||||
Arn string `json:"arn"`
|
||||
} `json:"profiles"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, profile := range result.Profiles {
|
||||
if profileArn := strings.TrimSpace(profile.Arn); profileArn != "" {
|
||||
if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil {
|
||||
fmt.Printf("[ProfileArn] Failed to cache profile ARN for %s: %v\n", account.Email, updateErr)
|
||||
}
|
||||
account.ProfileArn = profileArn
|
||||
return profileArn, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no available Kiro profile")
|
||||
}
|
||||
|
||||
func withProfileArnQuery(rawURL string, account *config.Account) string {
|
||||
if account == nil {
|
||||
return rawURL
|
||||
}
|
||||
profileArn := strings.TrimSpace(account.ProfileArn)
|
||||
if profileArn == "" {
|
||||
return rawURL
|
||||
}
|
||||
return rawURL + "&profileArn=" + neturl.QueryEscape(profileArn)
|
||||
}
|
||||
|
||||
func setKiroHeaders(req *http.Request, account *config.Account) {
|
||||
host := ""
|
||||
if req.URL != nil {
|
||||
|
||||
96
proxy/kiro_api_test.go
Normal file
96
proxy/kiro_api_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"kiro-go/config"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveProfileArnReturnsCachedValueWithoutRequest(t *testing.T) {
|
||||
kiroRestHttpStore.Store(&http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Fatal("unexpected HTTP request for cached profile ARN")
|
||||
return nil, nil
|
||||
}),
|
||||
})
|
||||
t.Cleanup(func() { InitKiroHttpClient("") })
|
||||
|
||||
account := &config.Account{ProfileArn: " arn:aws:codewhisperer:profile/test "}
|
||||
got, err := ResolveProfileArn(account)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "arn:aws:codewhisperer:profile/test" {
|
||||
t.Fatalf("expected trimmed cached ARN, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveProfileArnFetchesAndCachesProfile(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
if err := config.Init(configPath); err != nil {
|
||||
t.Fatalf("init config: %v", err)
|
||||
}
|
||||
account := config.Account{
|
||||
ID: "acct-1",
|
||||
Email: "user@example.com",
|
||||
AccessToken: "access-token",
|
||||
Region: "us-east-1",
|
||||
UsageCurrent: 7,
|
||||
}
|
||||
if err := config.AddAccount(account); err != nil {
|
||||
t.Fatalf("add account: %v", err)
|
||||
}
|
||||
|
||||
kiroRestHttpStore.Store(&http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", req.Method)
|
||||
}
|
||||
if req.URL.Path != "/ListAvailableProfiles" {
|
||||
t.Fatalf("expected ListAvailableProfiles path, got %s", req.URL.Path)
|
||||
}
|
||||
if got := req.Header.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected JSON content type, got %q", got)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"profiles":[{"arn":" arn:aws:codewhisperer:profile/fetched "}]} `)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
})
|
||||
t.Cleanup(func() { InitKiroHttpClient("") })
|
||||
|
||||
requestAccount := account
|
||||
requestAccount.UsageCurrent = 0
|
||||
got, err := ResolveProfileArn(&requestAccount)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "arn:aws:codewhisperer:profile/fetched" {
|
||||
t.Fatalf("expected fetched ARN, got %q", got)
|
||||
}
|
||||
if requestAccount.ProfileArn != got {
|
||||
t.Fatalf("expected account to be updated with fetched ARN, got %q", requestAccount.ProfileArn)
|
||||
}
|
||||
|
||||
accounts := config.GetAccounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected one persisted account, got %d", len(accounts))
|
||||
}
|
||||
if accounts[0].ProfileArn != got {
|
||||
t.Fatalf("expected persisted account profile ARN %q, got %q", got, accounts[0].ProfileArn)
|
||||
}
|
||||
if accounts[0].UsageCurrent != 7 {
|
||||
t.Fatalf("expected profile cache update to preserve usage fields, got usageCurrent=%v", accounts[0].UsageCurrent)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
Reference in New Issue
Block a user