fix: resolve Kiro profile ARN for generation requests (#46)

This commit is contained in:
Delicious233
2026-05-12 18:56:59 +08:00
committed by GitHub
parent f9e45a5f1d
commit 08a9747c99
4 changed files with 183 additions and 0 deletions

View File

@@ -50,6 +50,7 @@ type Account struct {
StartUrl string `json:"startUrl,omitempty"` // AWS SSO start URL
ExpiresAt int64 `json:"expiresAt,omitempty"` // Token expiration timestamp (Unix seconds)
MachineId string `json:"machineId,omitempty"` // UUID machine identifier for request tracking
ProfileArn string `json:"profileArn,omitempty"` // CodeWhisperer/Kiro profile ARN for generation requests
// Priority weight for load balancing (higher = more requests)
Weight int `json:"weight,omitempty"` // 0 or 1 = normal, 2+ = higher priority
@@ -274,6 +275,18 @@ func UpdateAccount(id string, account Account) error {
return nil
}
func UpdateAccountProfileArn(id, profileArn string) error {
cfgLock.Lock()
defer cfgLock.Unlock()
for i, a := range cfg.Accounts {
if a.ID == id {
cfg.Accounts[i].ProfileArn = profileArn
return Save()
}
}
return nil
}
func DeleteAccount(id string) error {
cfgLock.Lock()
defer cfgLock.Unlock()

View File

@@ -196,6 +196,17 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
if _, err := json.Marshal(payload); err != nil {
return err
}
if payload != nil && strings.TrimSpace(payload.ProfileArn) == "" {
if profileArn, err := ResolveProfileArn(account); err == nil {
payload.ProfileArn = profileArn
} else {
accountEmail := "<nil>"
if account != nil {
accountEmail = account.Email
}
fmt.Printf("[ProfileArn] Failed to resolve profile ARN for %s: %v\n", accountEmail, err)
}
}
// 根据配置排序端点
endpoints := getSortedEndpoints(config.GetPreferredEndpoint())

View File

@@ -6,6 +6,7 @@ import (
"io"
"kiro-go/config"
"net/http"
neturl "net/url"
"strings"
"time"
)
@@ -17,6 +18,7 @@ const (
// GetUsageLimits 获取账户使用量和订阅信息
func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) {
url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase)
url = withProfileArnQuery(url, account)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
@@ -77,6 +79,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) {
// ListAvailableModels 获取可用模型列表
func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase)
url = withProfileArnQuery(url, account)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
@@ -105,6 +108,66 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
return result.Models, nil
}
// ResolveProfileArn returns the account profile ARN, fetching and caching it
// when it is missing. Some Kiro generation requests require this profile for
// model authorization even when model listing works without it.
func ResolveProfileArn(account *config.Account) (string, error) {
if account == nil {
return "", fmt.Errorf("account is nil")
}
if profileArn := strings.TrimSpace(account.ProfileArn); profileArn != "" {
return profileArn, nil
}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/ListAvailableProfiles", kiroRestAPIBase), strings.NewReader(`{"maxResults":10}`))
if err != nil {
return "", err
}
setKiroHeaders(req, account)
req.Header.Set("Content-Type", "application/json")
resp, err := kiroRestHttpStore.Load().Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
Profiles []struct {
Arn string `json:"arn"`
} `json:"profiles"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
for _, profile := range result.Profiles {
if profileArn := strings.TrimSpace(profile.Arn); profileArn != "" {
if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil {
fmt.Printf("[ProfileArn] Failed to cache profile ARN for %s: %v\n", account.Email, updateErr)
}
account.ProfileArn = profileArn
return profileArn, nil
}
}
return "", fmt.Errorf("no available Kiro profile")
}
func withProfileArnQuery(rawURL string, account *config.Account) string {
if account == nil {
return rawURL
}
profileArn := strings.TrimSpace(account.ProfileArn)
if profileArn == "" {
return rawURL
}
return rawURL + "&profileArn=" + neturl.QueryEscape(profileArn)
}
func setKiroHeaders(req *http.Request, account *config.Account) {
host := ""
if req.URL != nil {

96
proxy/kiro_api_test.go Normal file
View File

@@ -0,0 +1,96 @@
package proxy
import (
"io"
"kiro-go/config"
"net/http"
"path/filepath"
"strings"
"testing"
)
func TestResolveProfileArnReturnsCachedValueWithoutRequest(t *testing.T) {
kiroRestHttpStore.Store(&http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
t.Fatal("unexpected HTTP request for cached profile ARN")
return nil, nil
}),
})
t.Cleanup(func() { InitKiroHttpClient("") })
account := &config.Account{ProfileArn: " arn:aws:codewhisperer:profile/test "}
got, err := ResolveProfileArn(account)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "arn:aws:codewhisperer:profile/test" {
t.Fatalf("expected trimmed cached ARN, got %q", got)
}
}
func TestResolveProfileArnFetchesAndCachesProfile(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
if err := config.Init(configPath); err != nil {
t.Fatalf("init config: %v", err)
}
account := config.Account{
ID: "acct-1",
Email: "user@example.com",
AccessToken: "access-token",
Region: "us-east-1",
UsageCurrent: 7,
}
if err := config.AddAccount(account); err != nil {
t.Fatalf("add account: %v", err)
}
kiroRestHttpStore.Store(&http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", req.Method)
}
if req.URL.Path != "/ListAvailableProfiles" {
t.Fatalf("expected ListAvailableProfiles path, got %s", req.URL.Path)
}
if got := req.Header.Get("Content-Type"); got != "application/json" {
t.Fatalf("expected JSON content type, got %q", got)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{"profiles":[{"arn":" arn:aws:codewhisperer:profile/fetched "}]} `)),
Header: make(http.Header),
}, nil
}),
})
t.Cleanup(func() { InitKiroHttpClient("") })
requestAccount := account
requestAccount.UsageCurrent = 0
got, err := ResolveProfileArn(&requestAccount)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "arn:aws:codewhisperer:profile/fetched" {
t.Fatalf("expected fetched ARN, got %q", got)
}
if requestAccount.ProfileArn != got {
t.Fatalf("expected account to be updated with fetched ARN, got %q", requestAccount.ProfileArn)
}
accounts := config.GetAccounts()
if len(accounts) != 1 {
t.Fatalf("expected one persisted account, got %d", len(accounts))
}
if accounts[0].ProfileArn != got {
t.Fatalf("expected persisted account profile ARN %q, got %q", got, accounts[0].ProfileArn)
}
if accounts[0].UsageCurrent != 7 {
t.Fatalf("expected profile cache update to preserve usage fields, got usageCurrent=%v", accounts[0].UsageCurrent)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}