Files
kirogo/proxy/kiro_api.go
hkxiaoyao d05bd00207 Merge pull request #3 from hkxiaoyao/main
feat(account): add trial quota tracking and display
2026-02-06 13:16:24 +08:00

294 lines
8.6 KiB
Go

package proxy
import (
"encoding/json"
"fmt"
"io"
"kiro-api-proxy/config"
"net/http"
"strings"
"time"
)
const (
	// kiroRestAPIBase is the base URL that the REST endpoint paths used by the
	// request helpers below (getUsageLimits, GetUserInfo, ListAvailableModels)
	// are appended to.
	kiroRestAPIBase = "https://codewhisperer.us-east-1.amazonaws.com"
	// kiroVersion is the Kiro IDE version string embedded in the User-Agent
	// and x-amz-user-agent headers built by setKiroHeaders.
	kiroVersion = "0.6.18"
)
// kiroRESTClient is shared by the Kiro REST helpers below so every call uses
// the same 30-second timeout and a single client value instead of allocating
// a new one per request.
var kiroRESTClient = &http.Client{Timeout: 30 * time.Second}

// maxKiroErrorBody caps how many bytes of a non-200 response body are copied
// into the returned error, so an unexpectedly large error page cannot bloat
// memory or log lines.
const maxKiroErrorBody = 64 << 10

// GetUsageLimits fetches the account's usage counters and subscription
// information from the Kiro getUsageLimits endpoint.
//
// A non-200 reply is returned as an error of the form "HTTP <code>: <body>".
func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) {
	url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	var result UsageLimitsResponse
	if err = doKiroRESTRequest(req, account, &result); err != nil {
		return nil, err
	}
	return &result, nil
}

// GetUserInfo fetches the account's user profile by POSTing
// {"origin":"KIRO_IDE"} to the Kiro GetUserInfo endpoint.
//
// A non-200 reply is returned as an error of the form "HTTP <code>: <body>".
func GetUserInfo(account *config.Account) (*UserInfoResponse, error) {
	url := fmt.Sprintf("%s/GetUserInfo", kiroRestAPIBase)
	payload := `{"origin":"KIRO_IDE"}`
	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	var result UserInfoResponse
	if err = doKiroRESTRequest(req, account, &result); err != nil {
		return nil, err
	}
	return &result, nil
}

// ListAvailableModels fetches the models available to the account from the
// Kiro ListAvailableModels endpoint and returns the decoded "models" array.
//
// NOTE(review): only a single request with maxResults=50 is made and no
// continuation token is followed; if the endpoint paginates, models past the
// first page are dropped — confirm against the API.
func ListAvailableModels(account *config.Account) ([]ModelInfo, error) {
	url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	var result struct {
		Models []ModelInfo `json:"models"`
	}
	if err = doKiroRESTRequest(req, account, &result); err != nil {
		return nil, err
	}
	return result.Models, nil
}

// doKiroRESTRequest stamps req with the account's Kiro headers, sends it with
// kiroRESTClient and JSON-decodes a 200 response body into out.
//
// Transport errors are returned as-is. A non-200 status becomes
// "HTTP <code>: <body>", with the body truncated to maxKiroErrorBody bytes.
// A malformed success body yields a wrapped decode error.
func doKiroRESTRequest(req *http.Request, account *config.Account, out any) error {
	setKiroHeaders(req, account)
	resp, err := kiroRESTClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is still a useful error even if the
		// body cannot be read, so a read failure is deliberately ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxKiroErrorBody))
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}
	if err = json.NewDecoder(resp.Body).Decode(out); err != nil {
		return fmt.Errorf("decode response: %w", err)
	}
	// Drain trailing bytes (e.g. a final newline) so the transport can reuse
	// the connection; failure here does not affect the decoded result.
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}
// setKiroHeaders attaches the headers carried by the Kiro REST requests built
// in this package: the account's bearer token, a JSON Accept header, SDK-style
// User-Agent and x-amz-user-agent strings (suffixed with the account's
// machine ID when one is set), and the CodeWhisperer opt-out flag.
func setKiroHeaders(req *http.Request, account *config.Account) {
	const sdkTag = "aws-sdk-js/1.0.18"
	ideTag := "KiroIDE-" + kiroVersion

	ua := sdkTag + " ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E " + ideTag
	amzUA := sdkTag + " " + ideTag
	if id := account.MachineId; id != "" {
		ua += "-" + id
		// NOTE(review): with a machine ID this header switches from the
		// hyphenated "KiroIDE-<version>" form to space-separated
		// "KiroIDE <version> <id>", while User-Agent stays hyphenated.
		// Confirm the upstream client really sends this shape.
		amzUA = sdkTag + " KiroIDE " + kiroVersion + " " + id
	}

	for _, kv := range [...][2]string{
		{"Authorization", "Bearer " + account.AccessToken},
		{"Accept", "application/json"},
		{"User-Agent", ua},
		{"x-amz-user-agent", amzUA},
		{"x-amzn-codewhisperer-optout", "true"},
	} {
		req.Header.Set(kv[0], kv[1])
	}
}
// RefreshAccountInfo refreshes the account's usage, subscription and trial
// details by calling GetUsageLimits and condensing the response into a new
// config.AccountInfo. LastRefresh is stamped with the time the refresh began.
//
// If GetUsageLimits fails, the error is returned wrapped with a
// "GetUsageLimits:" prefix and no AccountInfo is returned.
func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) {
	info := &config.AccountInfo{
		LastRefresh: time.Now().Unix(),
	}

	// Fetch usage and subscription information.
	usage, err := GetUsageLimits(account)
	if err != nil {
		return nil, fmt.Errorf("GetUsageLimits: %w", err)
	}

	// Identity.
	if usage.UserInfo != nil {
		info.Email = usage.UserInfo.Email
		info.UserId = usage.UserInfo.UserId
	}

	// Subscription: classify from the title, falling back to the name and then
	// the raw type; display the title, falling back to the name.
	if sub := usage.SubscriptionInfo; sub != nil {
		titleOrName := sub.SubscriptionTitle
		if titleOrName == "" {
			titleOrName = sub.SubscriptionName
		}
		if titleOrName == "" {
			titleOrName = sub.SubscriptionType
		}
		info.SubscriptionType = parseSubscriptionType(titleOrName)
		info.SubscriptionTitle = sub.SubscriptionTitle
		if info.SubscriptionTitle == "" {
			info.SubscriptionTitle = sub.SubscriptionName
		}
		fmt.Printf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s\n",
			sub.SubscriptionType,
			sub.SubscriptionTitle,
			sub.SubscriptionName,
			info.SubscriptionType)
	}

	// Regular usage and free-trial quota both come from the first breakdown
	// entry (the request asks for resourceType=AGENTIC_REQUEST).
	if len(usage.UsageBreakdownList) > 0 {
		breakdown := &usage.UsageBreakdownList[0]
		info.UsageCurrent = breakdown.CurrentUsage
		info.UsageLimit = breakdown.UsageLimit
		if info.UsageLimit > 0 {
			// Stored as a fraction (current/limit), not scaled to 100.
			info.UsagePercent = info.UsageCurrent / info.UsageLimit
		}

		if trial := breakdown.FreeTrialInfo; trial != nil {
			info.TrialUsageCurrent = trial.CurrentUsage
			info.TrialUsageLimit = trial.UsageLimit
			if info.TrialUsageLimit > 0 {
				info.TrialUsagePercent = info.TrialUsageCurrent / info.TrialUsageLimit
			}
			info.TrialStatus = trial.FreeTrialStatus
			if ts, ok := parseEpochSeconds(trial.FreeTrialExpiry); ok {
				info.TrialExpiresAt = ts
			}
		}
	}

	// Next quota reset date.
	if ts, ok := parseEpochSeconds(usage.NextDateReset); ok {
		// NOTE(review): formatted in the server's local time zone, so the date
		// can shift by a day depending on where the proxy runs — confirm
		// whether UTC is intended.
		info.NextResetDate = time.Unix(ts, 0).Format("2006-01-02")
	}

	return info, nil
}

// parseEpochSeconds interprets n as a Unix timestamp in seconds. It accepts an
// integer literal or a floating-point literal; fractional seconds are
// truncated. It reports false for an empty, unparsable, non-positive or
// int64-overflowing value.
func parseEpochSeconds(n json.Number) (int64, bool) {
	if n == "" {
		return 0, false
	}
	if ts, err := n.Int64(); err == nil {
		if ts > 0 {
			return ts, true
		}
		return 0, false
	}
	// Guard the upper bound: converting an out-of-range float64 to int64 is
	// implementation-defined in Go.
	if f, err := n.Float64(); err == nil && f > 0 && f < 1<<63 {
		return int64(f), true
	}
	return 0, false
}
// parseSubscriptionType maps a free-form subscription title, name or type to
// one of "PRO_PLUS", "POWER", "PRO" or "FREE" (case-insensitive).
//
// The PRO_PLUS check runs first, because every PRO_PLUS spelling also contains
// "PRO". Spaces, underscores and hyphens are ignored and "+" is read as "PLUS",
// so "Pro+", "PRO PLUS", "PRO-PLUS", "PRO_PLUS" and "PROPLUS" are all
// recognized. Anything matching no tier is treated as "FREE".
func parseSubscriptionType(raw string) string {
	upper := strings.ToUpper(raw)
	compact := strings.NewReplacer(" ", "", "_", "", "-", "", "+", "PLUS").Replace(upper)
	switch {
	case strings.Contains(compact, "PROPLUS"):
		return "PRO_PLUS"
	case strings.Contains(upper, "POWER"):
		return "POWER"
	case strings.Contains(upper, "PRO"):
		return "PRO"
	default:
		return "FREE"
	}
}
// Response types.

// UsageLimitsResponse is the JSON body decoded by GetUsageLimits.
// SubscriptionInfo and UserInfo are nil when the fields are absent or null.
type UsageLimitsResponse struct {
	// UsageBreakdownList holds usage entries; RefreshAccountInfo reads only
	// the first one.
	UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList"`
	// NextDateReset is kept as a json.Number; RefreshAccountInfo interprets
	// it as Unix seconds (integer or float literal).
	NextDateReset    json.Number       `json:"nextDateReset"`
	SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo"`
	UserInfo         *UserInfo         `json:"userInfo"`
}
// UsageBreakdown is one entry of UsageLimitsResponse.UsageBreakdownList.
// FreeTrialInfo is nil when the entry carries no free-trial data.
type UsageBreakdown struct {
	ResourceType string  `json:"resourceType"`
	CurrentUsage float64 `json:"currentUsage"`
	UsageLimit   float64 `json:"usageLimit"`
	// NOTE(review): Currency, Unit and OverageRate are decoded but not read in
	// this file; their exact meaning is not established here.
	Currency      string         `json:"currency"`
	Unit          string         `json:"unit"`
	OverageRate   float64        `json:"overageRate"`
	FreeTrialInfo *FreeTrialInfo `json:"freeTrialInfo"`
	Bonuses       []BonusInfo    `json:"bonuses"`
}
// FreeTrialInfo describes the free-trial quota attached to a UsageBreakdown.
type FreeTrialInfo struct {
	CurrentUsage    float64 `json:"currentUsage"`
	UsageLimit      float64 `json:"usageLimit"`
	FreeTrialStatus string  `json:"freeTrialStatus"`
	// FreeTrialExpiry is kept as a json.Number; RefreshAccountInfo interprets
	// it as Unix seconds.
	FreeTrialExpiry json.Number `json:"freeTrialExpiry"`
}
// BonusInfo is one entry of UsageBreakdown.Bonuses.
type BonusInfo struct {
	BonusCode    string  `json:"bonusCode"`
	DisplayName  string  `json:"displayName"`
	CurrentUsage float64 `json:"currentUsage"`
	UsageLimit   float64 `json:"usageLimit"`
	// ExpiresAt is kept as a json.Number; presumably a Unix timestamp like
	// the other expiry fields — TODO confirm units.
	ExpiresAt json.Number `json:"expiresAt"`
	Status    string      `json:"status"`
}
// SubscriptionInfo describes the account's plan. RefreshAccountInfo classifies
// the plan from SubscriptionTitle, falling back to SubscriptionName and then
// SubscriptionType.
type SubscriptionInfo struct {
	SubscriptionName  string `json:"subscriptionName"`
	SubscriptionTitle string `json:"subscriptionTitle"`
	SubscriptionType  string `json:"subscriptionType"`
	Status            string `json:"status"`
	UpgradeCapability string `json:"upgradeCapability"`
}
// UserInfo is the user identity embedded in UsageLimitsResponse.
type UserInfo struct {
	Email  string `json:"email"`
	UserId string `json:"userId"`
}
// UserInfoResponse is the JSON body decoded by GetUserInfo.
type UserInfoResponse struct {
	Email  string `json:"email"`
	UserId string `json:"userId"`
	Idp    string `json:"idp"`
	Status string `json:"status"`
}
// ModelInfo is one element of the "models" array returned by
// ListAvailableModels. TokenLimits is nil when the field is absent or null.
type ModelInfo struct {
	ModelId     string `json:"modelId"`
	ModelName   string `json:"modelName"`
	Description string `json:"description"`
	// InputTypes is decoded from the "supportedInputTypes" JSON field.
	InputTypes []string `json:"supportedInputTypes"`
	// NOTE(review): RateMultiplier is not interpreted in this file; presumably
	// a per-model usage cost factor — confirm.
	RateMultiplier float64 `json:"rateMultiplier"`
	TokenLimits    *struct {
		MaxInputTokens  int `json:"maxInputTokens"`
		MaxOutputTokens int `json:"maxOutputTokens"`
	} `json:"tokenLimits"`
}