From c5e6d421635807f3abdd96b031c40a2b6add4a20 Mon Sep 17 00:00:00 2001 From: Quorinex Date: Wed, 4 Feb 2026 00:37:05 +0800 Subject: [PATCH] feat: Kiro API Proxy - OpenAI/Anthropic compatible API service - Multi-account pool with round-robin load balancing - Auto token refresh for IAM IdC and Social auth - Streaming support (SSE) - Web admin panel with account management - Docker support with GitHub Actions CI/CD - Machine ID management per account - Usage tracking (requests, tokens, credits) --- .github/workflows/docker.yml | 60 ++ Dockerfile | 20 + README.md | 189 +++++ README_CN.md | 189 +++++ auth/iam_sso.go | 267 +++++++ auth/oidc.go | 102 +++ auth/sso_token.go | 338 +++++++++ config/config.go | 331 ++++++++ docker-compose.yml | 12 + go.mod | 5 + go.sum | 2 + main.go | 54 ++ pool/account.go | 189 +++++ proxy/handler.go | 1392 ++++++++++++++++++++++++++++++++++ proxy/kiro.go | 370 +++++++++ proxy/kiro_api.go | 271 +++++++ proxy/translator.go | 811 ++++++++++++++++++++ web/index.html | 616 +++++++++++++++ 18 files changed, 5218 insertions(+) create mode 100644 .github/workflows/docker.yml create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 README_CN.md create mode 100644 auth/iam_sso.go create mode 100644 auth/oidc.go create mode 100644 auth/sso_token.go create mode 100644 config/config.go create mode 100644 docker-compose.yml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 pool/account.go create mode 100644 proxy/handler.go create mode 100644 proxy/kiro.go create mode 100644 proxy/kiro_api.go create mode 100644 proxy/translator.go create mode 100644 web/index.html diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000..18047de --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,60 @@ +name: Build Docker Image + +on: + push: + branches: [main, master] + tags: ['v*'] + pull_request: + branches: [main, master] + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }}/kiro-api-proxy + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix= + + - name: Build and push + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..9834d80 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM golang:1.21-alpine AS builder + +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +RUN CGO_ENABLED=0 GOOS=linux go build -o kiro-api-proxy . + +FROM alpine:latest +RUN apk --no-cache add ca-certificates + +WORKDIR /app +COPY --from=builder /app/kiro-api-proxy . +COPY --from=builder /app/web ./web + +EXPOSE 8080 +VOLUME /app/data + +CMD ["./kiro-api-proxy"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..cbb044e --- /dev/null +++ b/README.md @@ -0,0 +1,189 @@ +# Kiro API Proxy + +[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://go.dev/) +[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED?style=flat&logo=docker)](https://www.docker.com/) +[![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) + +Convert Kiro accounts to OpenAI / Anthropic compatible API service. + +[English](README.md) | [中文](README_CN.md) + +## Features + +- 🔄 **Anthropic Claude API** - Full support for `/v1/messages` endpoint +- 🤖 **OpenAI Chat API** - Compatible with `/v1/chat/completions` +- ⚖️ **Multi-Account Pool** - Round-robin load balancing +- 🔐 **Auto Token Refresh** - Seamless token management +- 📡 **Streaming** - Real-time SSE responses +- 🎛️ **Web Admin Panel** - Easy account management +- 🔑 **Multiple Auth Methods** - IAM SSO, SSO Token, Credentials import +- 📊 **Usage Tracking** - Monitor requests, tokens, and credits + +## Quick Start + +### Docker Compose (Recommended) + +```bash +git clone https://github.com/Quorinex/kiro-api-proxy.git +cd kiro-api-proxy + +# Create data directory for persistence +mkdir -p data + +docker-compose up -d +``` + +### Docker Run + +```bash +# Create data directory +mkdir -p /path/to/data + +docker run -d \ + --name kiro-api-proxy \ + -p 8080:8080 \ + -e ADMIN_PASSWORD=your_secure_password \ + -v /path/to/data:/app/data \ + --restart unless-stopped \ + ghcr.io/quorinex/kiro-api-proxy:latest +``` + +> 📁 The `/app/data` volume stores `config.json` with accounts and settings. Mount it for data persistence. + +### Build from Source + +```bash +git clone https://github.com/Quorinex/kiro-api-proxy.git +cd kiro-api-proxy +go build -o kiro-api-proxy . +./kiro-api-proxy +``` + +## Configuration + +Config file is auto-created at `data/config.json` on first run: + +```json +{ + "password": "changeme", + "port": 8080, + "host": "127.0.0.1", + "requireApiKey": false, + "apiKey": "", + "accounts": [] +} +``` + +> ⚠️ **Change the default password before production use!** + +## Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `CONFIG_PATH` | Config file path | `data/config.json` | +| `ADMIN_PASSWORD` | Admin panel password (overrides config) | - | + +## Usage + +### 1. Access Admin Panel + +Open `http://localhost:8080/admin` and login with your password. + +### 2. Add Accounts + +Three methods available: + +| Method | Description | +|--------|-------------| +| **IAM SSO** | For enterprise users with SSO Start URL | +| **SSO Token** | Import `x-amz-sso_authn` from browser | +| **Credentials** | Import JSON from Kiro Account Manager | + +#### Credentials Format + +```json +{ + "refreshToken": "eyJ...", + "accessToken": "eyJ...", + "clientId": "xxx", + "clientSecret": "xxx" +} +``` + +### 3. Call API + +#### Claude API + +```bash +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "anthropic-version: 2023-06-01" \ + -d '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +#### OpenAI API + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any" \ + -d '{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +## Model Mapping + +| Request Model | Actual Model | +|---------------|--------------| +| `claude-sonnet-4-20250514` | claude-sonnet-4-20250514 | +| `claude-sonnet-4.5` | claude-sonnet-4.5 | +| `claude-haiku-4.5` | claude-haiku-4.5 | +| `claude-opus-4.5` | claude-opus-4.5 | +| `gpt-4o`, `gpt-4` | claude-sonnet-4-20250514 | +| `gpt-3.5-turbo` | claude-sonnet-4-20250514 | + +## API Endpoints + +| Endpoint | Description | +|----------|-------------| +| `GET /health` | Health check | +| `GET /v1/models` | List models | +| `POST /v1/messages` | Claude Messages API | +| `POST /v1/messages/count_tokens` | Token counting | +| `POST /v1/chat/completions` | OpenAI Chat API | +| `GET /admin` | Admin panel | + +## Project Structure + +``` +kiro-api-proxy/ +├── main.go # Entry point +├── config/ # Configuration management +├── pool/ # Account pool & load balancing +├── proxy/ # API handlers & Kiro client +│ ├── handler.go # HTTP routing & admin API +│ ├── kiro.go # Kiro API client +│ ├── kiro_api.go # Kiro REST API (usage, models) +│ └── translator.go # Request/response conversion +├── auth/ # Authentication +│ ├── iam_sso.go # IAM SSO login +│ ├── oidc.go # OIDC token refresh +│ └── sso_token.go # SSO token import +├── web/ # Admin panel frontend +├── Dockerfile +└── docker-compose.yml +``` + +## Disclaimer + +This project is for educational and research purposes only. Please comply with Kiro's Terms of Service. + +## License + +[MIT](LICENSE) diff --git a/README_CN.md b/README_CN.md new file mode 100644 index 0000000..3614c1b --- /dev/null +++ b/README_CN.md @@ -0,0 +1,189 @@ +# Kiro API Proxy + +[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://go.dev/) +[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED?style=flat&logo=docker)](https://www.docker.com/) +[![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) + +将 Kiro 账号转换为 OpenAI / Anthropic 兼容的 API 服务。 + +[English](README.md) | 中文 + +## 功能特性 + +- 🔄 **Anthropic Claude API** - 完整支持 `/v1/messages` 端点 +- 🤖 **OpenAI Chat API** - 兼容 `/v1/chat/completions` +- ⚖️ **多账号池** - 轮询负载均衡 +- 🔐 **自动刷新 Token** - 无缝 Token 管理 +- 📡 **流式响应** - 实时 SSE 输出 +- 🎛️ **Web 管理面板** - 便捷的账号管理 +- 🔑 **多种认证方式** - IAM SSO、SSO Token、凭证导入 +- 📊 **用量追踪** - 监控请求数、Token、Credits + +## 快速开始 + +### Docker Compose(推荐) + +```bash +git clone https://github.com/Quorinex/kiro-api-proxy.git +cd kiro-api-proxy + +# 创建数据目录用于持久化 +mkdir -p data + +docker-compose up -d +``` + +### Docker 运行 + +```bash +# 创建数据目录 +mkdir -p /path/to/data + +docker run -d \ + --name kiro-api-proxy \ + -p 8080:8080 \ + -e ADMIN_PASSWORD=your_secure_password \ + -v /path/to/data:/app/data \ + --restart unless-stopped \ + ghcr.io/quorinex/kiro-api-proxy:latest +``` + +> 📁 `/app/data` 卷存储 `config.json`(包含账号和设置),挂载此目录以实现数据持久化。 + +### 源码编译 + +```bash +git clone https://github.com/Quorinex/kiro-api-proxy.git +cd kiro-api-proxy +go build -o kiro-api-proxy . +./kiro-api-proxy +``` + +## 配置 + +首次运行会自动创建 `data/config.json`: + +```json +{ + "password": "changeme", + "port": 8080, + "host": "127.0.0.1", + "requireApiKey": false, + "apiKey": "", + "accounts": [] +} +``` + +> ⚠️ **生产环境请务必修改默认密码!** + +## 环境变量 + +| 变量 | 说明 | 默认值 | +|-----|------|-------| +| `CONFIG_PATH` | 配置文件路径 | `data/config.json` | +| `ADMIN_PASSWORD` | 管理面板密码(覆盖配置文件) | - | + +## 使用方法 + +### 1. 访问管理面板 + +打开 `http://localhost:8080/admin`,输入密码登录。 + +### 2. 添加账号 + +支持三种方式: + +| 方式 | 说明 | +|------|------| +| **IAM SSO** | 企业用户,输入 SSO Start URL | +| **SSO Token** | 从浏览器导入 `x-amz-sso_authn` | +| **凭证导入** | 从 Kiro Account Manager 导入 JSON | + +#### 凭证格式 + +```json +{ + "refreshToken": "eyJ...", + "accessToken": "eyJ...", + "clientId": "xxx", + "clientSecret": "xxx" +} +``` + +### 3. 调用 API + +#### Claude API + +```bash +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "anthropic-version: 2023-06-01" \ + -d '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "你好!"}] + }' +``` + +#### OpenAI API + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any" \ + -d '{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "你好!"}] + }' +``` + +## 模型映射 + +| 请求模型 | 实际模型 | +|---------|---------| +| `claude-sonnet-4-20250514` | claude-sonnet-4-20250514 | +| `claude-sonnet-4.5` | claude-sonnet-4.5 | +| `claude-haiku-4.5` | claude-haiku-4.5 | +| `claude-opus-4.5` | claude-opus-4.5 | +| `gpt-4o`, `gpt-4` | claude-sonnet-4-20250514 | +| `gpt-3.5-turbo` | claude-sonnet-4-20250514 | + +## API 端点 + +| 端点 | 说明 | +|-----|------| +| `GET /health` | 健康检查 | +| `GET /v1/models` | 模型列表 | +| `POST /v1/messages` | Claude Messages API | +| `POST /v1/messages/count_tokens` | Token 计数 | +| `POST /v1/chat/completions` | OpenAI Chat API | +| `GET /admin` | 管理面板 | + +## 项目结构 + +``` +kiro-api-proxy/ +├── main.go # 入口 +├── config/ # 配置管理 +├── pool/ # 账号池 & 负载均衡 +├── proxy/ # API 处理 & Kiro 客户端 +│ ├── handler.go # HTTP 路由 & 管理 API +│ ├── kiro.go # Kiro API 客户端 +│ ├── kiro_api.go # Kiro REST API(用量、模型) +│ └── translator.go # 请求/响应转换 +├── auth/ # 认证 +│ ├── iam_sso.go # IAM SSO 登录 +│ ├── oidc.go # OIDC Token 刷新 +│ └── sso_token.go # SSO Token 导入 +├── web/ # 管理面板前端 +├── Dockerfile +└── docker-compose.yml +``` + +## 免责声明 + +本项目仅供学习研究使用,请遵守 Kiro 服务条款。 + +## 许可证 + +[MIT](LICENSE) diff --git a/auth/iam_sso.go b/auth/iam_sso.go new file mode 100644 index 0000000..cc8ea06 --- /dev/null +++ b/auth/iam_sso.go @@ -0,0 +1,267 @@ +package auth + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sync" + "time" + + "github.com/google/uuid" +) + +type IamSsoSession struct { + ClientID string + ClientSecret string + CodeVerifier string + State string + Region string + StartUrl string + RedirectUri string + ExpiresAt time.Time +} + +var ( + sessions = make(map[string]*IamSsoSession) + sessionsMu sync.RWMutex +) + +var scopes = []string{ + "codewhisperer:completions", + "codewhisperer:analysis", + "codewhisperer:conversations", + "codewhisperer:transformations", + "codewhisperer:taskassist", +} + +// StartIamSsoLogin 发起 IAM SSO 登录 +func StartIamSsoLogin(startUrl, region string) (sessionID, authorizeUrl string, expiresIn int, err error) { + if region == "" { + region = "us-east-1" + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region) + redirectUri := "http://127.0.0.1/oauth/callback" + + // 1. 注册 OIDC 客户端 + clientID, clientSecret, err := registerOIDCClient(oidcBase, startUrl, redirectUri) + if err != nil { + return "", "", 0, fmt.Errorf("注册客户端失败: %w", err) + } + + // 2. 生成 PKCE + codeVerifier := generateCodeVerifier() + codeChallenge := generateCodeChallenge(codeVerifier) + state := uuid.New().String() + + // 3. 构建授权 URL + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", clientID) + params.Set("redirect_uri", redirectUri) + params.Set("scopes", joinScopes()) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + + authorizeUrl = fmt.Sprintf("%s/authorize?%s", oidcBase, params.Encode()) + + // 4. 保存会话 + sessionID = uuid.New().String() + session := &IamSsoSession{ + ClientID: clientID, + ClientSecret: clientSecret, + CodeVerifier: codeVerifier, + State: state, + Region: region, + StartUrl: startUrl, + RedirectUri: redirectUri, + ExpiresAt: time.Now().Add(10 * time.Minute), + } + + sessionsMu.Lock() + sessions[sessionID] = session + sessionsMu.Unlock() + + // 清理过期会话 + go cleanupExpiredSessions() + + return sessionID, authorizeUrl, 600, nil +} + +// CompleteIamSsoLogin 完成 IAM SSO 登录 +func CompleteIamSsoLogin(sessionID, callbackUrl string) (accessToken, refreshToken, clientID, clientSecret, region string, expiresIn int, err error) { + sessionsMu.RLock() + session, ok := sessions[sessionID] + sessionsMu.RUnlock() + + if !ok { + return "", "", "", "", "", 0, fmt.Errorf("会话不存在或已过期") + } + + if time.Now().After(session.ExpiresAt) { + sessionsMu.Lock() + delete(sessions, sessionID) + sessionsMu.Unlock() + return "", "", "", "", "", 0, fmt.Errorf("会话已过期") + } + + // 解析回调 URL + parsedUrl, err := url.Parse(callbackUrl) + if err != nil { + return "", "", "", "", "", 0, fmt.Errorf("无效的回调 URL") + } + + code := parsedUrl.Query().Get("code") + state := parsedUrl.Query().Get("state") + errorParam := parsedUrl.Query().Get("error") + + if errorParam != "" { + return "", "", "", "", "", 0, fmt.Errorf("授权失败: %s", errorParam) + } + + if state != session.State { + return "", "", "", "", "", 0, fmt.Errorf("状态不匹配,可能存在安全风险") + } + + if code == "" { + return "", "", "", "", "", 0, fmt.Errorf("未收到授权码") + } + + // 用 code 换取 token + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", session.Region) + accessToken, refreshToken, expiresIn, err = exchangeToken( + oidcBase, + session.ClientID, + session.ClientSecret, + code, + session.CodeVerifier, + session.RedirectUri, + ) + if err != nil { + return "", "", "", "", "", 0, err + } + + // 清理会话 + sessionsMu.Lock() + delete(sessions, sessionID) + sessionsMu.Unlock() + + return accessToken, refreshToken, session.ClientID, session.ClientSecret, session.Region, expiresIn, nil +} + +func registerOIDCClient(oidcBase, startUrl, redirectUri string) (clientID, clientSecret string, err error) { + payload := map[string]interface{}{ + "clientName": "Kiro API Proxy", + "clientType": "public", + "scopes": scopes, + "grantTypes": []string{"authorization_code", "refresh_token"}, + "redirectUris": []string{redirectUri}, + "issuerUrl": startUrl, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", "", err + } + + return result.ClientID, result.ClientSecret, nil +} + +func exchangeToken(oidcBase, clientID, clientSecret, code, codeVerifier, redirectUri string) (accessToken, refreshToken string, expiresIn int, err error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "grantType": "authorization_code", + "redirectUri": redirectUri, + "code": code, + "codeVerifier": codeVerifier, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", "", 0, err + } + + return result.AccessToken, result.RefreshToken, result.ExpiresIn, nil +} + +func generateCodeVerifier() string { + b := make([]byte, 32) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +func joinScopes() string { + result := "" + for i, s := range scopes { + if i > 0 { + result += "," + } + result += s + } + return result +} + +func cleanupExpiredSessions() { + sessionsMu.Lock() + defer sessionsMu.Unlock() + now := time.Now() + for id, s := range sessions { + if now.After(s.ExpiresAt) { + delete(sessions, id) + } + } +} diff --git a/auth/oidc.go b/auth/oidc.go new file mode 100644 index 0000000..5d656af --- /dev/null +++ b/auth/oidc.go @@ -0,0 +1,102 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "kiro-api-proxy/config" + "net/http" + "time" +) + +// RefreshToken 刷新 access token +func RefreshToken(account *config.Account) (string, string, int64, error) { + if account.AuthMethod == "social" { + return refreshSocialToken(account.RefreshToken) + } + return refreshOIDCToken(account.RefreshToken, account.ClientID, account.ClientSecret, account.Region) +} + +// refreshOIDCToken IdC/Builder ID token 刷新 +func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, error) { + if region == "" { + region = "us-east-1" + } + + url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + } + + var result struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", "", 0, err + } + + expiresAt := time.Now().Unix() + int64(result.ExpiresIn) + return result.AccessToken, result.RefreshToken, expiresAt, nil +} + +// refreshSocialToken Social (GitHub/Google) token 刷新 +func refreshSocialToken(refreshToken string) (string, string, int64, error) { + url := "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken" + + payload := map[string]string{ + "refreshToken": refreshToken, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + } + + var result struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", "", 0, err + } + + expiresAt := time.Now().Unix() + int64(result.ExpiresIn) + return result.AccessToken, result.RefreshToken, expiresAt, nil +} diff --git a/auth/sso_token.go b/auth/sso_token.go new file mode 100644 index 0000000..ed88783 --- /dev/null +++ b/auth/sso_token.go @@ -0,0 +1,338 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/google/uuid" +) + +// ImportFromSsoToken 从 SSO Token (x-amz-sso_authn) 导入账号 +func ImportFromSsoToken(bearerToken, region string) (accessToken, refreshToken, clientID, clientSecret string, expiresIn int, err error) { + if region == "" { + region = "us-east-1" + } + + oidcBase := fmt.Sprintf("https://oidc.%s.amazonaws.com", region) + portalBase := "https://portal.sso.us-east-1.amazonaws.com" + startUrl := "https://view.awsapps.com/start" + + // 1. 注册 OIDC 客户端 + clientID, clientSecret, err = registerDeviceClient(oidcBase, startUrl) + if err != nil { + return "", "", "", "", 0, fmt.Errorf("注册客户端失败: %w", err) + } + + // 2. 发起设备授权 + deviceCode, userCode, interval, err := startDeviceAuth(oidcBase, clientID, clientSecret, startUrl) + if err != nil { + return "", "", "", "", 0, fmt.Errorf("设备授权失败: %w", err) + } + + // 3. 验证 Bearer Token + if err := verifyBearerToken(portalBase, bearerToken); err != nil { + return "", "", "", "", 0, fmt.Errorf("Token 验证失败: %w", err) + } + + // 4. 获取设备会话令牌 + deviceSessionToken, err := getDeviceSessionToken(portalBase, bearerToken) + if err != nil { + return "", "", "", "", 0, fmt.Errorf("获取设备会话失败: %w", err) + } + + // 5. 接受用户代码 + deviceContext, err := acceptUserCode(oidcBase, userCode, deviceSessionToken) + if err != nil { + return "", "", "", "", 0, fmt.Errorf("接受用户代码失败: %w", err) + } + + // 6. 批准授权 + if deviceContext != nil { + if err := approveAuth(oidcBase, deviceContext, deviceSessionToken); err != nil { + return "", "", "", "", 0, fmt.Errorf("批准授权失败: %w", err) + } + } + + // 7. 轮询获取 Token + accessToken, refreshToken, expiresIn, err = pollForToken(oidcBase, clientID, clientSecret, deviceCode, interval) + if err != nil { + return "", "", "", "", 0, fmt.Errorf("获取 Token 失败: %w", err) + } + + return accessToken, refreshToken, clientID, clientSecret, expiresIn, nil +} + +func registerDeviceClient(oidcBase, startUrl string) (clientID, clientSecret string, err error) { + payload := map[string]interface{}{ + "clientName": "Kiro API Proxy", + "clientType": "public", + "scopes": scopes, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + "issuerUrl": startUrl, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + } + json.NewDecoder(resp.Body).Decode(&result) + return result.ClientID, result.ClientSecret, nil +} + +func startDeviceAuth(oidcBase, clientID, clientSecret, startUrl string) (deviceCode, userCode string, interval int, err error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": startUrl, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", "", 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + Interval int `json:"interval"` + } + json.NewDecoder(resp.Body).Decode(&result) + if result.Interval == 0 { + result.Interval = 1 + } + return result.DeviceCode, result.UserCode, result.Interval, nil +} + +func verifyBearerToken(portalBase, bearerToken string) error { + req, _ := http.NewRequest("GET", portalBase+"/token/whoAmI", nil) + req.Header.Set("Authorization", "Bearer "+bearerToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("HTTP %d", resp.StatusCode) + } + return nil +} + +func getDeviceSessionToken(portalBase, bearerToken string) (string, error) { + req, _ := http.NewRequest("POST", portalBase+"/session/device", bytes.NewReader([]byte("{}"))) + req.Header.Set("Authorization", "Bearer "+bearerToken) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Token string `json:"token"` + } + json.NewDecoder(resp.Body).Decode(&result) + return result.Token, nil +} + +type deviceContextInfo struct { + DeviceContextID string `json:"deviceContextId"` + ClientID string `json:"clientId"` + ClientType string `json:"clientType"` +} + +func acceptUserCode(oidcBase, userCode, deviceSessionToken string) (*deviceContextInfo, error) { + payload := map[string]string{ + "userCode": userCode, + "userSessionId": deviceSessionToken, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/device_authorization/accept_user_code", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Referer", "https://view.awsapps.com/") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + DeviceContext *deviceContextInfo `json:"deviceContext"` + } + json.NewDecoder(resp.Body).Decode(&result) + return result.DeviceContext, nil +} + +func approveAuth(oidcBase string, deviceContext *deviceContextInfo, deviceSessionToken string) error { + payload := map[string]interface{}{ + "deviceContext": map[string]string{ + "deviceContextId": deviceContext.DeviceContextID, + "clientId": deviceContext.ClientID, + "clientType": deviceContext.ClientType, + }, + "userSessionId": deviceSessionToken, + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/device_authorization/associate_token", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Referer", "https://view.awsapps.com/") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) + } + return nil +} + +func pollForToken(oidcBase, clientID, clientSecret, deviceCode string, interval int) (accessToken, refreshToken string, expiresIn int, err error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + "deviceCode": deviceCode, + } + + timeout := time.After(2 * time.Minute) + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-timeout: + return "", "", 0, fmt.Errorf("授权超时") + case <-ticker.C: + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + continue + } + + if resp.StatusCode == 200 { + var result struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + } + json.NewDecoder(resp.Body).Decode(&result) + resp.Body.Close() + return result.AccessToken, result.RefreshToken, result.ExpiresIn, nil + } + + if resp.StatusCode == 400 { + var errResult struct { + Error string `json:"error"` + } + json.NewDecoder(resp.Body).Decode(&errResult) + resp.Body.Close() + + if errResult.Error == "authorization_pending" { + continue + } else if errResult.Error == "slow_down" { + interval += 5 + ticker.Reset(time.Duration(interval) * time.Second) + continue + } + return "", "", 0, fmt.Errorf("授权错误: %s", errResult.Error) + } + resp.Body.Close() + } + } +} + +// GetUserInfo 获取用户信息 +func GetUserInfo(accessToken string) (email, userID string, err error) { + // 调用 Kiro API 获取用量信息(包含用户信息) + url := "https://q.us-east-1.amazonaws.com/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true" + + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "aws-sdk-js/1.0.18 KiroAPIProxy") + req.Header.Set("x-amz-user-agent", "aws-sdk-js/1.0.18 KiroAPIProxy") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "", "", fmt.Errorf("HTTP %d", resp.StatusCode) + } + + var result struct { + UserInfo struct { + Email string `json:"email"` + UserID string `json:"userId"` + } `json:"userInfo"` + } + json.NewDecoder(resp.Body).Decode(&result) + return result.UserInfo.Email, result.UserInfo.UserID, nil +} + +// GenerateAccountID 生成账号 ID +func GenerateAccountID() string { + return uuid.New().String() +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..849a6cb --- /dev/null +++ b/config/config.go @@ -0,0 +1,331 @@ +// Package config 配置管理模块 +// 负责账号、设置、统计数据的持久化存储 +package config + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "os" + "sync" +) + +// GenerateMachineId 生成 UUID v4 格式的机器码 +func GenerateMachineId() string { + bytes := make([]byte, 16) + rand.Read(bytes) + bytes[6] = (bytes[6] & 0x0f) | 0x40 // 版本 4 + bytes[8] = (bytes[8] & 0x3f) | 0x80 // 变体 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) +} + +// Account 账号信息 +type Account struct { + // 基本信息 + ID string `json:"id"` + Email string `json:"email,omitempty"` + UserId string `json:"userId,omitempty"` + Nickname string `json:"nickname,omitempty"` + + // 认证信息 + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ClientID string `json:"clientId,omitempty"` + ClientSecret string `json:"clientSecret,omitempty"` + AuthMethod string `json:"authMethod"` // idc | social + Provider string `json:"provider,omitempty"` + Region string `json:"region"` + StartUrl string `json:"startUrl,omitempty"` + ExpiresAt int64 `json:"expiresAt,omitempty"` + MachineId string `json:"machineId,omitempty"` // UUID 格式机器码 + + // 状态 + Enabled bool `json:"enabled"` + + // 订阅信息 + SubscriptionType string `json:"subscriptionType,omitempty"` // FREE | PRO | PRO_PLUS | POWER + SubscriptionTitle string `json:"subscriptionTitle,omitempty"` + DaysRemaining int `json:"daysRemaining,omitempty"` + + // 使用量 + UsageCurrent float64 `json:"usageCurrent,omitempty"` + UsageLimit float64 `json:"usageLimit,omitempty"` + UsagePercent float64 `json:"usagePercent,omitempty"` + NextResetDate string `json:"nextResetDate,omitempty"` + LastRefresh int64 `json:"lastRefresh,omitempty"` + + // 运行时统计 + RequestCount int `json:"requestCount,omitempty"` + ErrorCount int `json:"errorCount,omitempty"` + LastUsed int64 `json:"lastUsed,omitempty"` + TotalTokens int `json:"totalTokens,omitempty"` + TotalCredits float64 `json:"totalCredits,omitempty"` +} + +// Config 全局配置 +type Config struct { + Password string `json:"password"` + Port int `json:"port"` + Host string `json:"host"` + ApiKey string `json:"apiKey,omitempty"` + RequireApiKey bool `json:"requireApiKey"` + Accounts []Account `json:"accounts"` + + // 全局统计 + TotalRequests int `json:"totalRequests,omitempty"` + SuccessRequests int `json:"successRequests,omitempty"` + FailedRequests int `json:"failedRequests,omitempty"` + TotalTokens int `json:"totalTokens,omitempty"` + TotalCredits float64 `json:"totalCredits,omitempty"` +} + +// AccountInfo 账户信息更新结构 +type AccountInfo struct { + Email string + UserId string + SubscriptionType string + SubscriptionTitle string + DaysRemaining int + UsageCurrent float64 + UsageLimit float64 + UsagePercent float64 + NextResetDate string + LastRefresh int64 +} + +var ( + cfg *Config + cfgLock sync.RWMutex + cfgPath string +) + +// Init 初始化配置 +func Init(path string) error { + cfgPath = path + return Load() +} + +// Load 从文件加载配置 +func Load() error { + cfgLock.Lock() + defer cfgLock.Unlock() + + data, err := os.ReadFile(cfgPath) + if err != nil { + if os.IsNotExist(err) { + // 创建默认配置 + cfg = &Config{ + Password: "changeme", + Port: 8080, + Host: "127.0.0.1", + RequireApiKey: false, + Accounts: []Account{}, + } + return Save() + } + return err + } + + var c Config + if err := json.Unmarshal(data, &c); err != nil { + return err + } + cfg = &c + return nil +} + +// Save 保存配置到文件 +func Save() error { + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(cfgPath, data, 0600) +} + +// SetPassword 设置密码(用于环境变量覆盖) +func SetPassword(password string) { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.Password = password +} + +func Get() *Config { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg +} + +func GetPassword() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg.Password +} + +func GetPort() int { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg.Port == 0 { + return 8080 + } + return cfg.Port +} + +func GetHost() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg.Host == "" { + return "127.0.0.1" + } + return cfg.Host +} + +func GetAccounts() []Account { + cfgLock.RLock() + defer cfgLock.RUnlock() + accounts := make([]Account, len(cfg.Accounts)) + copy(accounts, cfg.Accounts) + return accounts +} + +func GetEnabledAccounts() []Account { + cfgLock.RLock() + defer cfgLock.RUnlock() + var accounts []Account + for _, a := range cfg.Accounts { + if a.Enabled { + accounts = append(accounts, a) + } + } + return accounts +} + +func AddAccount(account Account) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.Accounts = append(cfg.Accounts, account) + return Save() +} + +func UpdateAccount(id string, account Account) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts[i] = account + return Save() + } + } + return nil +} + +func DeleteAccount(id string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts = append(cfg.Accounts[:i], cfg.Accounts[i+1:]...) + return Save() + } + } + return nil +} + +func UpdateAccountToken(id, accessToken, refreshToken string, expiresAt int64) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts[i].AccessToken = accessToken + if refreshToken != "" { + cfg.Accounts[i].RefreshToken = refreshToken + } + cfg.Accounts[i].ExpiresAt = expiresAt + return Save() + } + } + return nil +} + +func GetApiKey() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg.ApiKey +} + +func IsApiKeyRequired() bool { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg.RequireApiKey +} + +func UpdateSettings(apiKey string, requireApiKey bool, password string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.ApiKey = apiKey + cfg.RequireApiKey = requireApiKey + if password != "" { + cfg.Password = password + } + return Save() +} + +func UpdateStats(totalReq, successReq, failedReq, totalTokens int, totalCredits float64) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.TotalRequests = totalReq + cfg.SuccessRequests = successReq + cfg.FailedRequests = failedReq + cfg.TotalTokens = totalTokens + cfg.TotalCredits = totalCredits + return Save() +} + +func GetStats() (int, int, int, int, float64) { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg.TotalRequests, cfg.SuccessRequests, cfg.FailedRequests, cfg.TotalTokens, cfg.TotalCredits +} + +func UpdateAccountStats(id string, requestCount, errorCount, totalTokens int, totalCredits float64, lastUsed int64) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts[i].RequestCount = requestCount + cfg.Accounts[i].ErrorCount = errorCount + cfg.Accounts[i].TotalTokens = totalTokens + cfg.Accounts[i].TotalCredits = totalCredits + cfg.Accounts[i].LastUsed = lastUsed + return Save() + } + } + return nil +} + +// UpdateAccountInfo 更新账户的订阅和使用量信息 +func UpdateAccountInfo(id string, info AccountInfo) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + if info.Email != "" { + cfg.Accounts[i].Email = info.Email + } + if info.UserId != "" { + cfg.Accounts[i].UserId = info.UserId + } + cfg.Accounts[i].SubscriptionType = info.SubscriptionType + cfg.Accounts[i].SubscriptionTitle = info.SubscriptionTitle + cfg.Accounts[i].DaysRemaining = info.DaysRemaining + cfg.Accounts[i].UsageCurrent = info.UsageCurrent + cfg.Accounts[i].UsageLimit = info.UsageLimit + cfg.Accounts[i].UsagePercent = info.UsagePercent + cfg.Accounts[i].NextResetDate = info.NextResetDate + cfg.Accounts[i].LastRefresh = info.LastRefresh + return Save() + } + } + return nil +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..48780b6 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,12 @@ +version: '3.8' + +services: + kiro-api-proxy: + build: . + ports: + - "8080:8080" + volumes: + - ./data:/app/data + environment: + - CONFIG_PATH=/app/data/config.json + restart: unless-stopped diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f1bd668 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module kiro-api-proxy + +go 1.21 + +require github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7790d7c --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/main.go b/main.go new file mode 100644 index 0000000..12b8357 --- /dev/null +++ b/main.go @@ -0,0 +1,54 @@ +// Kiro API Proxy - 将 Kiro API 转换为 OpenAI/Anthropic 兼容格式 +// 支持多账号池、自动 Token 刷新、流式响应 +package main + +import ( + "fmt" + "kiro-api-proxy/config" + "kiro-api-proxy/pool" + "kiro-api-proxy/proxy" + "log" + "net/http" + "os" + "path/filepath" +) + +func main() { + // 配置文件路径,支持环境变量覆盖 + configPath := "data/config.json" + if envPath := os.Getenv("CONFIG_PATH"); envPath != "" { + configPath = envPath + } + + // 确保数据目录存在 + if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { + log.Fatalf("Failed to create data directory: %v", err) + } + + // 加载配置 + if err := config.Init(configPath); err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // 环境变量覆盖密码 + if envPassword := os.Getenv("ADMIN_PASSWORD"); envPassword != "" { + config.SetPassword(envPassword) + } + + // 初始化账号池 + pool.GetPool() + + // 创建 HTTP 处理器(包含后台刷新任务) + handler := proxy.NewHandler() + + // 启动服务器 + addr := fmt.Sprintf("%s:%d", config.GetHost(), config.GetPort()) + log.Printf("Kiro API Proxy starting on http://%s", addr) + log.Printf("Admin panel: http://%s/admin", addr) + log.Printf("Claude API: http://%s/v1/messages", addr) + log.Printf("OpenAI API: http://%s/v1/chat/completions", addr) + + if err := http.ListenAndServe(addr, handler); err != nil { + log.Fatalf("Server failed: %v", err) + } +} diff --git a/pool/account.go b/pool/account.go new file mode 100644 index 0000000..f259a25 --- /dev/null +++ b/pool/account.go @@ -0,0 +1,189 @@ +// Package pool 账号池管理 +// 实现轮询负载均衡、错误冷却、Token 刷新 +package pool + +import ( + "kiro-api-proxy/config" + "sync" + "sync/atomic" + "time" +) + +// AccountPool 账号池 +type AccountPool struct { + mu sync.RWMutex + accounts []config.Account + currentIndex uint64 + cooldowns map[string]time.Time // 账号冷却时间 + errorCounts map[string]int // 连续错误计数 +} + +var ( + pool *AccountPool + poolOnce sync.Once +) + +// GetPool 获取全局账号池单例 +func GetPool() *AccountPool { + poolOnce.Do(func() { + pool = &AccountPool{ + cooldowns: make(map[string]time.Time), + errorCounts: make(map[string]int), + } + pool.Reload() + }) + return pool +} + +// Reload 从配置重新加载账号 +func (p *AccountPool) Reload() { + p.mu.Lock() + defer p.mu.Unlock() + p.accounts = config.GetEnabledAccounts() +} + +// GetNext 获取下一个可用账号(轮询) +func (p *AccountPool) GetNext() *config.Account { + p.mu.RLock() + defer p.mu.RUnlock() + + if len(p.accounts) == 0 { + return nil + } + + now := time.Now() + n := len(p.accounts) + + // 轮询查找可用账号 + for i := 0; i < n; i++ { + idx := atomic.AddUint64(&p.currentIndex, 1) % uint64(n) + acc := &p.accounts[idx] + + // 跳过冷却中的账号 + if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) { + continue + } + + // 跳过即将过期的 Token + if acc.ExpiresAt > 0 && time.Now().Unix() > acc.ExpiresAt-300 { + continue + } + + return acc + } + + // 无可用账号,返回冷却时间最短的 + var best *config.Account + var earliest time.Time + for i := range p.accounts { + acc := &p.accounts[i] + if cooldown, ok := p.cooldowns[acc.ID]; ok { + if best == nil || cooldown.Before(earliest) { + best = acc + earliest = cooldown + } + } else { + return acc + } + } + return best +} + +// GetByID 根据 ID 获取账号 +func (p *AccountPool) GetByID(id string) *config.Account { + p.mu.RLock() + defer p.mu.RUnlock() + for i := range p.accounts { + if p.accounts[i].ID == id { + return &p.accounts[i] + } + } + return nil +} + +// RecordSuccess 记录请求成功,清除冷却 +func (p *AccountPool) RecordSuccess(id string) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.cooldowns, id) + p.errorCounts[id] = 0 +} + +// RecordError 记录请求错误,设置冷却 +func (p *AccountPool) RecordError(id string, isQuotaError bool) { + p.mu.Lock() + defer p.mu.Unlock() + + p.errorCounts[id]++ + + if isQuotaError { + // 配额错误,冷却 1 小时 + p.cooldowns[id] = time.Now().Add(time.Hour) + } else if p.errorCounts[id] >= 3 { + // 连续 3 次错误,冷却 1 分钟 + p.cooldowns[id] = time.Now().Add(time.Minute) + } +} + +// UpdateToken 更新账号 Token +func (p *AccountPool) UpdateToken(id, accessToken, refreshToken string, expiresAt int64) { + p.mu.Lock() + defer p.mu.Unlock() + for i := range p.accounts { + if p.accounts[i].ID == id { + p.accounts[i].AccessToken = accessToken + if refreshToken != "" { + p.accounts[i].RefreshToken = refreshToken + } + p.accounts[i].ExpiresAt = expiresAt + break + } + } +} + +// Count 返回账号总数 +func (p *AccountPool) Count() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.accounts) +} + +// AvailableCount 返回可用账号数 +func (p *AccountPool) AvailableCount() int { + p.mu.RLock() + defer p.mu.RUnlock() + now := time.Now() + count := 0 + for _, acc := range p.accounts { + if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) { + continue + } + count++ + } + return count +} + +// UpdateStats 更新账号统计 +func (p *AccountPool) UpdateStats(id string, tokens int, credits float64) { + p.mu.Lock() + defer p.mu.Unlock() + for i := range p.accounts { + if p.accounts[i].ID == id { + p.accounts[i].RequestCount++ + p.accounts[i].TotalTokens += tokens + p.accounts[i].TotalCredits += credits + p.accounts[i].LastUsed = time.Now().Unix() + go config.UpdateAccountStats(id, p.accounts[i].RequestCount, p.accounts[i].ErrorCount, p.accounts[i].TotalTokens, p.accounts[i].TotalCredits, p.accounts[i].LastUsed) + break + } + } +} + +// GetAllAccounts 获取所有账号副本 +func (p *AccountPool) GetAllAccounts() []config.Account { + p.mu.RLock() + defer p.mu.RUnlock() + result := make([]config.Account, len(p.accounts)) + copy(result, p.accounts) + return result +} diff --git a/proxy/handler.go b/proxy/handler.go new file mode 100644 index 0000000..0c7f9ae --- /dev/null +++ b/proxy/handler.go @@ -0,0 +1,1392 @@ +package proxy + +import ( + "encoding/json" + "fmt" + "io" + "kiro-api-proxy/auth" + "kiro-api-proxy/config" + "kiro-api-proxy/pool" + "net/http" + "strings" + "time" + + "github.com/google/uuid" +) + +// Handler HTTP 处理器 +type Handler struct { + pool *pool.AccountPool + // 运行时统计 + totalRequests int + successRequests int + failedRequests int + totalTokens int + totalCredits float64 + startTime int64 + stopRefresh chan struct{} +} + +func NewHandler() *Handler { + totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats() + h := &Handler{ + pool: pool.GetPool(), + totalRequests: totalReq, + successRequests: successReq, + failedRequests: failedReq, + totalTokens: totalTokens, + totalCredits: totalCredits, + startTime: time.Now().Unix(), + stopRefresh: make(chan struct{}), + } + // 启动后台刷新 + go h.backgroundRefresh() + return h +} + +// backgroundRefresh 后台定时刷新账户信息 +func (h *Handler) backgroundRefresh() { + ticker := time.NewTicker(30 * time.Minute) // 每 30 分钟刷新一次 + defer ticker.Stop() + + // 启动时延迟 10 秒后执行一次 + time.Sleep(10 * time.Second) + h.refreshAllAccounts() + + for { + select { + case <-ticker.C: + h.refreshAllAccounts() + case <-h.stopRefresh: + return + } + } +} + +// refreshAllAccounts 刷新所有账户信息 +func (h *Handler) refreshAllAccounts() { + accounts := config.GetAccounts() + for i := range accounts { + account := &accounts[i] + if !account.Enabled || account.AccessToken == "" { + continue + } + + // 检查 token 是否需要刷新 + if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 { + newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + if err != nil { + fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", account.Email, err) + continue + } + account.AccessToken = newAccessToken + if newRefreshToken != "" { + account.RefreshToken = newRefreshToken + } + account.ExpiresAt = newExpiresAt + config.UpdateAccountToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) + h.pool.UpdateToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) + } + + // 刷新账户信息 + info, err := RefreshAccountInfo(account) + if err != nil { + fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", account.Email, err) + continue + } + + config.UpdateAccountInfo(account.ID, *info) + fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) + } + h.pool.Reload() +} + +// validateApiKey 验证 API Key +func (h *Handler) validateApiKey(r *http.Request) bool { + if !config.IsApiKeyRequired() { + return true + } + + expectedKey := config.GetApiKey() + if expectedKey == "" { + return true + } + + // 从 Authorization 头或 X-Api-Key 头获取 + authHeader := r.Header.Get("Authorization") + apiKeyHeader := r.Header.Get("X-Api-Key") + + var providedKey string + if strings.HasPrefix(authHeader, "Bearer ") { + providedKey = strings.TrimPrefix(authHeader, "Bearer ") + } else if apiKeyHeader != "" { + providedKey = apiKeyHeader + } + + return providedKey == expectedKey +} + +// ServeHTTP 路由分发 +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + // CORS - 完整的头部支持 + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Api-Key, anthropic-version, anthropic-beta, x-api-key, x-stainless-os, x-stainless-lang, x-stainless-package-version, x-stainless-runtime, x-stainless-runtime-version, x-stainless-arch") + w.Header().Set("Access-Control-Expose-Headers", "x-request-id, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens, x-ratelimit-reset-requests, x-ratelimit-reset-tokens") + + if r.Method == "OPTIONS" { + w.WriteHeader(204) + return + } + + // 路由 + switch { + // API 端点(需要验证 API Key) + case path == "/v1/messages" || path == "/messages" || path == "/anthropic/v1/messages": + if !h.validateApiKey(r) { + h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key") + return + } + h.handleClaudeMessages(w, r) + case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens": + if !h.validateApiKey(r) { + h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key") + return + } + h.handleCountTokens(w, r) + case path == "/v1/chat/completions" || path == "/chat/completions": + if !h.validateApiKey(r) { + h.sendOpenAIError(w, 401, "authentication_error", "Invalid or missing API key") + return + } + h.handleOpenAIChat(w, r) + case path == "/v1/models" || path == "/models": + h.handleModels(w, r) + case path == "/api/event_logging/batch": + // Claude Code 遥测端点 - 直接返回 200 OK + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write([]byte(`{"status":"ok"}`)) + + // 管理端点 + case path == "/admin" || path == "/admin/": + h.serveAdminPage(w, r) + case strings.HasPrefix(path, "/admin/api/"): + h.handleAdminAPI(w, r) + case strings.HasPrefix(path, "/admin/"): + h.serveStaticFile(w, r) + + // 健康检查 + case path == "/health" || path == "/": + h.handleHealth(w, r) + + default: + http.Error(w, "Not Found", 404) + } +} + +// handleHealth 健康检查 +func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "accounts": h.pool.Count(), + "available": h.pool.AvailableCount(), + "totalRequests": h.totalRequests, + "successRequests": h.successRequests, + "failedRequests": h.failedRequests, + "totalTokens": h.totalTokens, + "totalCredits": h.totalCredits, + "uptime": time.Now().Unix() - h.startTime, + }) +} + +// handleModels 模型列表 +func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) { + models := []map[string]interface{}{ + {"id": "claude-sonnet-4.5", "object": "model", "owned_by": "anthropic"}, + {"id": "claude-sonnet-4", "object": "model", "owned_by": "anthropic"}, + {"id": "claude-haiku-4.5", "object": "model", "owned_by": "anthropic"}, + {"id": "claude-opus-4.5", "object": "model", "owned_by": "anthropic"}, + {"id": "auto", "object": "model", "owned_by": "kiro-api"}, + {"id": "gpt-4o", "object": "model", "owned_by": "kiro-proxy"}, + {"id": "gpt-4", "object": "model", "owned_by": "kiro-proxy"}, + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]interface{}{ + "object": "list", + "data": models, + }) +} + +// handleCountTokens Token 计数(Claude Code 会调用) +func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Method Not Allowed", 405) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body") + return + } + + var req struct { + Messages []struct { + Role string `json:"role"` + Content interface{} `json:"content"` + } `json:"messages"` + System interface{} `json:"system"` + } + if err := json.Unmarshal(body, &req); err != nil { + h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON") + return + } + + // 简单估算 token 数量(每 4 个字符约 1 个 token) + var totalChars int + for _, msg := range req.Messages { + switch content := msg.Content.(type) { + case string: + totalChars += len(content) + case []interface{}: + for _, part := range content { + if p, ok := part.(map[string]interface{}); ok { + if text, ok := p["text"].(string); ok { + totalChars += len(text) + } + } + } + } + } + + // 系统提示 + switch system := req.System.(type) { + case string: + totalChars += len(system) + case []interface{}: + for _, part := range system { + if p, ok := part.(map[string]interface{}); ok { + if text, ok := p["text"].(string); ok { + totalChars += len(text) + } + } + } + } + + estimatedTokens := (totalChars + 3) / 4 // 向上取整 + if estimatedTokens < 1 { + estimatedTokens = 1 + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]int{"input_tokens": estimatedTokens}) +} + +// handleClaudeMessages Claude API 处理 +func (h *Handler) handleClaudeMessages(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Method Not Allowed", 405) + return + } + + // 读取请求 + body, err := io.ReadAll(r.Body) + if err != nil { + h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body") + return + } + + var req ClaudeRequest + if err := json.Unmarshal(body, &req); err != nil { + h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+err.Error()) + return + } + + // 获取账号 + account := h.pool.GetNext() + if account == nil { + h.sendClaudeError(w, 503, "api_error", "No available accounts") + return + } + + // 检查并刷新 token + if err := h.ensureValidToken(account); err != nil { + h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+err.Error()) + return + } + + // 转换请求 + kiroPayload := ClaudeToKiro(&req) + + // 流式或非流式 + if req.Stream { + h.handleClaudeStream(w, account, kiroPayload, req.Model) + } else { + h.handleClaudeNonStream(w, account, kiroPayload, req.Model) + } +} + +// handleClaudeStream Claude 流式响应 +func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + h.sendClaudeError(w, 500, "api_error", "Streaming not supported") + return + } + + msgID := "msg_" + uuid.New().String() + var contentStarted bool + var toolUseIndex int + var inputTokens, outputTokens int + var credits float64 + var toolUses []KiroToolUse + + // 发送 message_start + h.sendSSE(w, flusher, "message_start", map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": msgID, + "type": "message", + "role": "assistant", + "content": []interface{}{}, + "model": model, + }, + }) + + callback := &KiroStreamCallback{ + OnText: func(text string, isThinking bool) { + if text == "" { + return + } + // 确保 content_block 已开始 + if !contentStarted { + h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]string{"type": "text", "text": ""}, + }) + contentStarted = true + } + // 直接转发文本,不缓冲 + outputText := text + if isThinking { + outputText = "" + text + "" + } + h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]string{"type": "text_delta", "text": outputText}, + }) + }, + OnToolUse: func(tu KiroToolUse) { + toolUses = append(toolUses, tu) + + // 关闭文本块 + if contentStarted && toolUseIndex == 0 { + h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{ + "type": "content_block_stop", + "index": 0, + }) + } + + idx := toolUseIndex + if contentStarted { + idx = toolUseIndex + 1 + } + toolUseIndex++ + + h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]interface{}{ + "type": "tool_use", + "id": tu.ToolUseID, + "name": tu.Name, + "input": map[string]interface{}{}, + }, + }) + + inputJSON, _ := json.Marshal(tu.Input) + h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": string(inputJSON), + }, + }) + + h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{ + "type": "content_block_stop", + "index": idx, + }) + }, + OnComplete: func(inTok, outTok int) { + inputTokens = inTok + outputTokens = outTok + }, + OnError: func(err error) { + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota")) + }, + OnCredits: func(c float64) { + credits = c + }, + } + + err := CallKiroAPI(account, payload, callback) + if err != nil { + h.recordFailure() + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota")) + h.sendSSE(w, flusher, "error", map[string]interface{}{ + "type": "error", + "error": map[string]string{"type": "api_error", "message": err.Error()}, + }) + return + } + + h.recordSuccess(inputTokens, outputTokens, credits) + h.pool.RecordSuccess(account.ID) + h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + + // 关闭最后的内容块 + if contentStarted && toolUseIndex == 0 { + h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{ + "type": "content_block_stop", + "index": 0, + }) + } + + // 发送 message_delta + stopReason := "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + + h.sendSSE(w, flusher, "message_delta", map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": stopReason, + }, + "usage": map[string]int{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + }, + }) + + h.sendSSE(w, flusher, "message_stop", map[string]interface{}{ + "type": "message_stop", + }) +} + +func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event string, data interface{}) { + jsonData, _ := json.Marshal(data) + fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, string(jsonData)) + flusher.Flush() +} + +// 统计记录 +func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) { + h.totalRequests++ + h.successRequests++ + h.totalTokens += inputTokens + outputTokens + h.totalCredits += credits + // 异步保存 + go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits) +} + +func (h *Handler) recordFailure() { + h.totalRequests++ + h.failedRequests++ + go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits) +} + +// handleClaudeNonStream Claude 非流式响应 +func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) { + var content string + var toolUses []KiroToolUse + var inputTokens, outputTokens int + var credits float64 + + callback := &KiroStreamCallback{ + OnText: func(text string, isThinking bool) { + if isThinking { + content += "" + text + "" + } else { + content += text + } + }, + OnToolUse: func(tu KiroToolUse) { + toolUses = append(toolUses, tu) + }, + OnComplete: func(inTok, outTok int) { + inputTokens = inTok + outputTokens = outTok + }, + OnError: func(err error) { + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) + }, + OnCredits: func(c float64) { + credits = c + }, + } + + err := CallKiroAPI(account, payload, callback) + if err != nil { + h.recordFailure() + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) + h.sendClaudeError(w, 500, "api_error", err.Error()) + return + } + + h.recordSuccess(inputTokens, outputTokens, credits) + h.pool.RecordSuccess(account.ID) + h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + + resp := KiroToClaudeResponse(content, toolUses, inputTokens, outputTokens, model) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(resp) +} + +func (h *Handler) sendClaudeError(w http.ResponseWriter, status int, errType, message string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]interface{}{ + "type": "error", + "error": map[string]string{ + "type": errType, + "message": message, + }, + }) +} + +// handleOpenAIChat OpenAI API 处理 +func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Method Not Allowed", 405) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + h.sendOpenAIError(w, 400, "invalid_request_error", "Failed to read request body") + return + } + + var req OpenAIRequest + if err := json.Unmarshal(body, &req); err != nil { + h.sendOpenAIError(w, 400, "invalid_request_error", "Invalid JSON") + return + } + + account := h.pool.GetNext() + if account == nil { + h.sendOpenAIError(w, 503, "server_error", "No available accounts") + return + } + + if err := h.ensureValidToken(account); err != nil { + h.sendOpenAIError(w, 503, "server_error", "Token refresh failed") + return + } + + kiroPayload := OpenAIToKiro(&req) + + if req.Stream { + h.handleOpenAIStream(w, account, kiroPayload, req.Model) + } else { + h.handleOpenAINonStream(w, account, kiroPayload, req.Model) + } +} + +// handleOpenAIStream OpenAI 流式响应 +func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + h.sendOpenAIError(w, 500, "server_error", "Streaming not supported") + return + } + + chatID := "chatcmpl-" + uuid.New().String() + var toolCalls []ToolCall + var toolCallIndex int + var inputTokens, outputTokens int + var credits float64 + + callback := &KiroStreamCallback{ + OnText: func(text string, isThinking bool) { + if text == "" { + return + } + // 直接转发,不缓冲 + deltaKey := "content" + if isThinking { + deltaKey = "reasoning_content" + } + chunk := map[string]interface{}{ + "id": chatID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{{ + "index": 0, + "delta": map[string]string{deltaKey: text}, + "finish_reason": nil, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", string(data)) + flusher.Flush() + }, + OnToolUse: func(tu KiroToolUse) { + args, _ := json.Marshal(tu.Input) + tc := ToolCall{ID: tu.ToolUseID, Type: "function"} + tc.Function.Name = tu.Name + tc.Function.Arguments = string(args) + toolCalls = append(toolCalls, tc) + + chunk := map[string]interface{}{ + "id": chatID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{{ + "index": 0, + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{{ + "index": toolCallIndex, + "id": tu.ToolUseID, + "type": "function", + "function": map[string]string{ + "name": tu.Name, + "arguments": string(args), + }, + }}, + }, + "finish_reason": nil, + }}, + } + toolCallIndex++ + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", string(data)) + flusher.Flush() + }, + OnComplete: func(inTok, outTok int) { + inputTokens = inTok + outputTokens = outTok + }, + OnError: func(err error) { + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) + }, + OnCredits: func(c float64) { + credits = c + }, + } + + err := CallKiroAPI(account, payload, callback) + if err != nil { + h.recordFailure() + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) + return + } + + h.recordSuccess(inputTokens, outputTokens, credits) + h.pool.RecordSuccess(account.ID) + h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + + // 发送结束 + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + chunk := map[string]interface{}{ + "id": chatID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{{ + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", string(data)) + fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() +} + +// handleOpenAINonStream OpenAI 非流式响应 +func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string) { + var content string + var toolUses []KiroToolUse + var inputTokens, outputTokens int + var credits float64 + + callback := &KiroStreamCallback{ + OnText: func(text string, isThinking bool) { + if isThinking { + // 非流式模式下,thinking 内容可以作为单独字段或忽略 + // 这里暂时忽略 + } else { + content += text + } + }, + OnToolUse: func(tu KiroToolUse) { toolUses = append(toolUses, tu) }, + OnComplete: func(inTok, outTok int) { inputTokens = inTok; outputTokens = outTok }, + OnError: func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) }, + OnCredits: func(c float64) { credits = c }, + } + + err := CallKiroAPI(account, payload, callback) + if err != nil { + h.recordFailure() + h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) + h.sendOpenAIError(w, 500, "server_error", err.Error()) + return + } + + h.recordSuccess(inputTokens, outputTokens, credits) + h.pool.RecordSuccess(account.ID) + h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + + resp := KiroToOpenAIResponse(content, toolUses, inputTokens, outputTokens, model) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(resp) +} + +func (h *Handler) sendOpenAIError(w http.ResponseWriter, status int, errType, message string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]interface{}{ + "type": errType, + "message": message, + }, + }) +} + +// ensureValidToken 确保 token 有效 +func (h *Handler) ensureValidToken(account *config.Account) error { + if account.ExpiresAt == 0 || time.Now().Unix() < account.ExpiresAt-300 { + return nil + } + + accessToken, refreshToken, expiresAt, err := auth.RefreshToken(account) + if err != nil { + return err + } + + // 更新内存 + h.pool.UpdateToken(account.ID, accessToken, refreshToken, expiresAt) + account.AccessToken = accessToken + if refreshToken != "" { + account.RefreshToken = refreshToken + } + account.ExpiresAt = expiresAt + + // 持久化 + config.UpdateAccountToken(account.ID, accessToken, refreshToken, expiresAt) + + return nil +} + +// ==================== 管理 API ==================== + +func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { + // 验证密码 + password := r.Header.Get("X-Admin-Password") + if password == "" { + cookie, _ := r.Cookie("admin_password") + if cookie != nil { + password = cookie.Value + } + } + + if password != config.GetPassword() { + w.WriteHeader(401) + json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"}) + return + } + + path := strings.TrimPrefix(r.URL.Path, "/admin/api") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + switch { + case path == "/accounts" && r.Method == "GET": + h.apiGetAccounts(w, r) + case path == "/accounts" && r.Method == "POST": + h.apiAddAccount(w, r) + case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == "POST": + id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh") + h.apiRefreshAccount(w, r, id) + case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/models") && r.Method == "GET": + id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/models") + h.apiGetAccountModels(w, r, id) + case strings.HasPrefix(path, "/accounts/") && r.Method == "DELETE": + h.apiDeleteAccount(w, r, strings.TrimPrefix(path, "/accounts/")) + case strings.HasPrefix(path, "/accounts/") && r.Method == "PUT": + h.apiUpdateAccount(w, r, strings.TrimPrefix(path, "/accounts/")) + case path == "/auth/iam-sso/start" && r.Method == "POST": + h.apiStartIamSso(w, r) + case path == "/auth/iam-sso/complete" && r.Method == "POST": + h.apiCompleteIamSso(w, r) + case path == "/auth/sso-token" && r.Method == "POST": + h.apiImportSsoToken(w, r) + case path == "/auth/credentials" && r.Method == "POST": + h.apiImportCredentials(w, r) + case path == "/status" && r.Method == "GET": + h.apiGetStatus(w, r) + case path == "/settings" && r.Method == "GET": + h.apiGetSettings(w, r) + case path == "/settings" && r.Method == "POST": + h.apiUpdateSettings(w, r) + case path == "/stats" && r.Method == "GET": + h.apiGetStats(w, r) + case path == "/stats/reset" && r.Method == "POST": + h.apiResetStats(w, r) + case path == "/generate-machine-id" && r.Method == "GET": + h.apiGenerateMachineId(w, r) + default: + w.WriteHeader(404) + json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"}) + } +} + +func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) { + accounts := config.GetAccounts() + poolAccounts := h.pool.GetAllAccounts() + + // 合并运行时统计 + statsMap := make(map[string]config.Account) + for _, a := range poolAccounts { + statsMap[a.ID] = a + } + + // 隐藏敏感信息 + result := make([]map[string]interface{}, len(accounts)) + for i, a := range accounts { + // 获取运行时统计 + stats := statsMap[a.ID] + + result[i] = map[string]interface{}{ + "id": a.ID, + "email": a.Email, + "userId": a.UserId, + "nickname": a.Nickname, + "authMethod": a.AuthMethod, + "provider": a.Provider, + "region": a.Region, + "enabled": a.Enabled, + "expiresAt": a.ExpiresAt, + "hasToken": a.AccessToken != "", + "machineId": a.MachineId, + "subscriptionType": a.SubscriptionType, + "subscriptionTitle": a.SubscriptionTitle, + "daysRemaining": a.DaysRemaining, + "usageCurrent": a.UsageCurrent, + "usageLimit": a.UsageLimit, + "usagePercent": a.UsagePercent, + "nextResetDate": a.NextResetDate, + "lastRefresh": a.LastRefresh, + "requestCount": stats.RequestCount, + "errorCount": stats.ErrorCount, + "totalTokens": stats.TotalTokens, + "totalCredits": stats.TotalCredits, + "lastUsed": stats.LastUsed, + } + } + json.NewEncoder(w).Encode(result) +} + +func (h *Handler) apiAddAccount(w http.ResponseWriter, r *http.Request) { + var account config.Account + if err := json.NewDecoder(r.Body).Decode(&account); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if account.ID == "" { + account.ID = auth.GenerateAccountID() + } + if account.Region == "" { + account.Region = "us-east-1" + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "id": account.ID}) +} + +func (h *Handler) apiDeleteAccount(w http.ResponseWriter, r *http.Request, id string) { + if err := config.DeleteAccount(id); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + +func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id string) { + var updates map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&updates); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + // 获取现有账号 + accounts := config.GetAccounts() + var existing *config.Account + for i := range accounts { + if accounts[i].ID == id { + existing = &accounts[i] + break + } + } + if existing == nil { + w.WriteHeader(404) + json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"}) + return + } + + // 只更新传入的字段 + if v, ok := updates["enabled"].(bool); ok { + existing.Enabled = v + } + if v, ok := updates["nickname"].(string); ok { + existing.Nickname = v + } + if v, ok := updates["machineId"].(string); ok { + existing.MachineId = v + } + + if err := config.UpdateAccount(id, *existing); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + +func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) { + var req struct { + StartUrl string `json:"startUrl"` + Region string `json:"region"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if req.StartUrl == "" { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "startUrl is required"}) + return + } + + sessionID, authorizeUrl, expiresIn, err := auth.StartIamSsoLogin(req.StartUrl, req.Region) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "sessionId": sessionID, + "authorizeUrl": authorizeUrl, + "expiresIn": expiresIn, + }) +} + +func (h *Handler) apiCompleteIamSso(w http.ResponseWriter, r *http.Request) { + var req struct { + SessionID string `json:"sessionId"` + CallbackUrl string `json:"callbackUrl"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + accessToken, refreshToken, clientID, clientSecret, region, expiresIn, err := auth.CompleteIamSsoLogin(req.SessionID, req.CallbackUrl) + if err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + // 获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Region: region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "account": map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }, + }) +} + +func (h *Handler) apiImportSsoToken(w http.ResponseWriter, r *http.Request) { + var req struct { + BearerToken string `json:"bearerToken"` + Region string `json:"region"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if req.BearerToken == "" { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "bearerToken is required"}) + return + } + + accessToken, refreshToken, clientID, clientSecret, expiresIn, err := auth.ImportFromSsoToken(req.BearerToken, req.Region) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + // 获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + ClientSecret: clientSecret, + AuthMethod: "idc", + Region: req.Region, + ExpiresAt: time.Now().Unix() + int64(expiresIn), + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "account": map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }, + }) +} + +func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) { + var req struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AuthMethod string `json:"authMethod"` + Provider string `json:"provider"` + Region string `json:"region"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if req.RefreshToken == "" { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "refreshToken is required"}) + return + } + + // 设置默认值 + if req.Region == "" { + req.Region = "us-east-1" + } + if req.AuthMethod == "" { + if req.ClientID != "" { + req.AuthMethod = "idc" + } else { + req.AuthMethod = "social" + } + } + + // 如果没有 accessToken,尝试刷新获取 + accessToken := req.AccessToken + var expiresAt int64 + if accessToken == "" { + tempAccount := &config.Account{ + RefreshToken: req.RefreshToken, + ClientID: req.ClientID, + ClientSecret: req.ClientSecret, + AuthMethod: req.AuthMethod, + Region: req.Region, + } + newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount) + if err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()}) + return + } + accessToken = newAccessToken + if newRefreshToken != "" { + req.RefreshToken = newRefreshToken + } + expiresAt = newExpiresAt + } else { + expiresAt = time.Now().Unix() + 3600 // 默认 1 小时 + } + + // 获取用户信息 + email, _, _ := auth.GetUserInfo(accessToken) + + // 创建账号 + account := config.Account{ + ID: auth.GenerateAccountID(), + Email: email, + AccessToken: accessToken, + RefreshToken: req.RefreshToken, + ClientID: req.ClientID, + ClientSecret: req.ClientSecret, + AuthMethod: req.AuthMethod, + Provider: req.Provider, + Region: req.Region, + ExpiresAt: expiresAt, + Enabled: true, + MachineId: config.GenerateMachineId(), + } + + if err := config.AddAccount(account); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "account": map[string]interface{}{ + "id": account.ID, + "email": account.Email, + }, + }) +} + +func (h *Handler) apiGetStatus(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "accounts": h.pool.Count(), + "available": h.pool.AvailableCount(), + "totalRequests": h.totalRequests, + "successRequests": h.successRequests, + "failedRequests": h.failedRequests, + "totalTokens": h.totalTokens, + "totalCredits": h.totalCredits, + "uptime": time.Now().Unix() - h.startTime, + }) +} + +func (h *Handler) apiGetSettings(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "apiKey": config.GetApiKey(), + "requireApiKey": config.IsApiKeyRequired(), + "port": config.GetPort(), + "host": config.GetHost(), + }) +} + +func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) { + var req struct { + ApiKey string `json:"apiKey"` + RequireApiKey bool `json:"requireApiKey"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if err := config.UpdateSettings(req.ApiKey, req.RequireApiKey, req.Password); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + +func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "totalRequests": h.totalRequests, + "successRequests": h.successRequests, + "failedRequests": h.failedRequests, + "totalTokens": h.totalTokens, + "totalCredits": h.totalCredits, + "uptime": time.Now().Unix() - h.startTime, + }) +} + +func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) { + h.totalRequests = 0 + h.successRequests = 0 + h.failedRequests = 0 + h.totalTokens = 0 + h.totalCredits = 0 + config.UpdateStats(0, 0, 0, 0, 0) + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + +// apiGenerateMachineId 生成新的机器码 +func (h *Handler) apiGenerateMachineId(w http.ResponseWriter, r *http.Request) { + machineId := config.GenerateMachineId() + json.NewEncoder(w).Encode(map[string]string{"machineId": machineId}) +} + +// apiRefreshAccount 刷新账户信息(使用量、订阅等) +func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id string) { + accounts := config.GetAccounts() + var account *config.Account + for i := range accounts { + if accounts[i].ID == id { + account = &accounts[i] + break + } + } + + if account == nil { + w.WriteHeader(404) + json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"}) + return + } + + // 检查 token 是否过期,需要刷新 + if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-60 { + newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": "Token refresh failed: " + err.Error()}) + return + } + account.AccessToken = newAccessToken + if newRefreshToken != "" { + account.RefreshToken = newRefreshToken + } + account.ExpiresAt = newExpiresAt + config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt) + h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt) + } + + // 获取账户信息 + info, err := RefreshAccountInfo(account) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + // 保存到配置 + if err := config.UpdateAccountInfo(id, *info); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "info": info, + }) +} + +// apiGetAccountModels 获取账户可用模型 +func (h *Handler) apiGetAccountModels(w http.ResponseWriter, r *http.Request, id string) { + accounts := config.GetAccounts() + var account *config.Account + for i := range accounts { + if accounts[i].ID == id { + account = &accounts[i] + break + } + } + + if account == nil { + w.WriteHeader(404) + json.NewEncoder(w).Encode(map[string]string{"error": "Account not found"}) + return + } + + models, err := ListAvailableModels(account) + if err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "models": models, + }) +} + +// ==================== 静态文件服务 ==================== + +func (h *Handler) serveAdminPage(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, "web/index.html") +} + +func (h *Handler) serveStaticFile(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/admin/") + http.ServeFile(w, r, "web/"+path) +} diff --git a/proxy/kiro.go b/proxy/kiro.go new file mode 100644 index 0000000..ccea09f --- /dev/null +++ b/proxy/kiro.go @@ -0,0 +1,370 @@ +// Package proxy Kiro API 代理核心 +// 负责调用 Kiro API 并解析 AWS Event Stream 响应 +package proxy + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "kiro-api-proxy/config" + "net/http" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + KiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse" + KiroVersion = "0.6.18" +) + +// ==================== 请求结构 ==================== + +// KiroPayload Kiro API 请求体 +type KiroPayload struct { + ConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` + ConversationID string `json:"conversationId"` + CurrentMessage struct { + UserInputMessage KiroUserInputMessage `json:"userInputMessage"` + } `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` + } `json:"conversationState"` + ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"` +} + +type KiroUserInputMessage struct { + Content string `json:"content"` + ModelID string `json:"modelId,omitempty"` + Origin string `json:"origin"` + Images []KiroImage `json:"images,omitempty"` + UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"` +} + +type UserInputMessageContext struct { + Tools []KiroToolWrapper `json:"tools,omitempty"` + ToolResults []KiroToolResult `json:"toolResults,omitempty"` +} + +type KiroToolWrapper struct { + ToolSpecification struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema InputSchema `json:"inputSchema"` + } `json:"toolSpecification"` +} + +type InputSchema struct { + JSON interface{} `json:"json"` +} + +type KiroToolResult struct { + ToolUseID string `json:"toolUseId"` + Content []KiroResultContent `json:"content"` + Status string `json:"status"` +} + +type KiroResultContent struct { + Text string `json:"text"` +} + +type KiroImage struct { + Format string `json:"format"` + Source struct { + Bytes string `json:"bytes"` + } `json:"source"` +} + +type KiroHistoryMessage struct { + UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"` + AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` +} + +type KiroAssistantResponseMessage struct { + Content string `json:"content"` + ToolUses []KiroToolUse `json:"toolUses,omitempty"` +} + +type KiroToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +type InferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// ==================== 流式回调 ==================== + +// KiroStreamCallback 流式响应回调 +type KiroStreamCallback struct { + OnText func(text string, isThinking bool) + OnToolUse func(toolUse KiroToolUse) + OnComplete func(inputTokens, outputTokens int) + OnError func(err error) + OnCredits func(credits float64) +} + +// ==================== API 调用 ==================== + +// CallKiroAPI 调用 Kiro API(流式) +func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", KiroEndpoint, bytes.NewReader(body)) + if err != nil { + return err + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + req.Header.Set("X-Amz-Target", "AmazonCodeWhispererStreamingService.GenerateAssistantResponse") + + // User-Agent 包含机器码 + machineId := account.MachineId + var userAgent, amzUserAgent string + if machineId != "" { + userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", KiroVersion, machineId) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", KiroVersion, machineId) + } else { + userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", KiroVersion) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s", KiroVersion) + } + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Amz-User-Agent", amzUserAgent) + req.Header.Set("x-amzn-kiro-agent-mode", "spec") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + req.Header.Set("Authorization", "Bearer "+account.AccessToken) + + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + return parseEventStream(resp.Body, callback) +} + +// ==================== Event Stream 解析 ==================== + +// parseEventStream 解析 AWS Event Stream 二进制格式 +func parseEventStream(body io.Reader, callback *KiroStreamCallback) error { + reader := bufio.NewReader(body) + + var inputTokens, outputTokens int + var totalOutputChars int + var totalCredits float64 + var currentToolUse *toolUseState + + for { + // Prelude: 12 bytes (total_len + headers_len + crc) + prelude := make([]byte, 12) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + break + } + if err != nil { + return err + } + + totalLength := int(prelude[0])<<24 | int(prelude[1])<<16 | int(prelude[2])<<8 | int(prelude[3]) + headersLength := int(prelude[4])<<24 | int(prelude[5])<<16 | int(prelude[6])<<8 | int(prelude[7]) + + if totalLength < 16 { + continue + } + + // 读取剩余部分 + remaining := totalLength - 12 + msgBuf := make([]byte, remaining) + _, err = io.ReadFull(reader, msgBuf) + if err != nil { + return err + } + + if headersLength > len(msgBuf)-4 { + continue + } + + eventType := extractEventType(msgBuf[0:headersLength]) + payloadBytes := msgBuf[headersLength : len(msgBuf)-4] + if len(payloadBytes) == 0 { + continue + } + + var event map[string]interface{} + if err := json.Unmarshal(payloadBytes, &event); err != nil { + continue + } + + // 处理事件 + switch eventType { + case "assistantResponseEvent": + if content, ok := event["content"].(string); ok && content != "" { + callback.OnText(content, false) + totalOutputChars += len(content) + } + case "reasoningContentEvent": + if text, ok := event["text"].(string); ok && text != "" { + callback.OnText(text, true) + totalOutputChars += len(text) + } + case "toolUseEvent": + currentToolUse = handleToolUseEvent(event, currentToolUse, callback) + case "messageMetadataEvent", "metadataEvent": + if tokenUsage, ok := event["tokenUsage"].(map[string]interface{}); ok { + if v, ok := tokenUsage["outputTokens"].(float64); ok { + outputTokens = int(v) + } + uncached, _ := tokenUsage["uncachedInputTokens"].(float64) + cacheRead, _ := tokenUsage["cacheReadInputTokens"].(float64) + cacheWrite, _ := tokenUsage["cacheWriteInputTokens"].(float64) + inputTokens = int(uncached + cacheRead + cacheWrite) + } + case "meteringEvent": + if usage, ok := event["usage"].(float64); ok { + totalCredits += usage + } + } + } + + // 估算 token(约 3 字符 = 1 token) + if outputTokens == 0 && totalOutputChars > 0 { + outputTokens = max(1, totalOutputChars/3) + } + + if callback.OnCredits != nil && totalCredits > 0 { + callback.OnCredits(totalCredits) + } + + callback.OnComplete(inputTokens, outputTokens) + return nil +} + +// ==================== Tool Use 处理 ==================== + +type toolUseState struct { + ToolUseID string + Name string + InputBuffer strings.Builder +} + +func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState { + toolUseID, _ := event["toolUseId"].(string) + name, _ := event["name"].(string) + isStop, _ := event["stop"].(bool) + + if toolUseID != "" && name != "" { + if current == nil { + current = &toolUseState{ToolUseID: toolUseID, Name: name} + } else if current.ToolUseID != toolUseID { + finishToolUse(current, callback) + current = &toolUseState{ToolUseID: toolUseID, Name: name} + } + } + + if current != nil { + if input, ok := event["input"].(string); ok { + current.InputBuffer.WriteString(input) + } else if inputObj, ok := event["input"].(map[string]interface{}); ok { + data, _ := json.Marshal(inputObj) + current.InputBuffer.Reset() + current.InputBuffer.Write(data) + } + } + + if isStop && current != nil { + finishToolUse(current, callback) + return nil + } + + return current +} + +func finishToolUse(state *toolUseState, callback *KiroStreamCallback) { + var input map[string]interface{} + if state.InputBuffer.Len() > 0 { + json.Unmarshal([]byte(state.InputBuffer.String()), &input) + } + if input == nil { + input = make(map[string]interface{}) + } + callback.OnToolUse(KiroToolUse{ + ToolUseID: state.ToolUseID, + Name: state.Name, + Input: input, + }) +} + +// extractEventType 从 headers 中提取事件类型 +func extractEventType(headers []byte) string { + offset := 0 + for offset < len(headers) { + if offset >= len(headers) { + break + } + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String + if offset+2 > len(headers) { + break + } + valueLen := int(headers[offset])<<8 | int(headers[offset+1]) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + if name == ":event-type" { + return value + } + continue + } + + // 跳过其他类型 + skipSizes := map[byte]int{0: 0, 1: 0, 2: 1, 3: 2, 4: 4, 5: 8, 8: 8, 9: 16} + if valueType == 6 { + if offset+2 > len(headers) { + break + } + l := int(headers[offset])<<8 | int(headers[offset+1]) + offset += 2 + l + } else if skip, ok := skipSizes[valueType]; ok { + offset += skip + } else { + break + } + } + return "" +} diff --git a/proxy/kiro_api.go b/proxy/kiro_api.go new file mode 100644 index 0000000..fc81dc8 --- /dev/null +++ b/proxy/kiro_api.go @@ -0,0 +1,271 @@ +package proxy + +import ( + "encoding/json" + "fmt" + "io" + "kiro-api-proxy/config" + "net/http" + "strings" + "time" +) + +const ( + kiroRestAPIBase = "https://codewhisperer.us-east-1.amazonaws.com" + kiroVersion = "0.6.18" +) + +// GetUsageLimits 获取账户使用量和订阅信息 +func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) { + url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + setKiroHeaders(req, account) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result UsageLimitsResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + return &result, nil +} + +// GetUserInfo 获取用户信息 +func GetUserInfo(account *config.Account) (*UserInfoResponse, error) { + url := fmt.Sprintf("%s/GetUserInfo", kiroRestAPIBase) + + payload := `{"origin":"KIRO_IDE"}` + req, err := http.NewRequest("POST", url, strings.NewReader(payload)) + if err != nil { + return nil, err + } + + setKiroHeaders(req, account) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result UserInfoResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + return &result, nil +} + +// ListAvailableModels 获取可用模型列表 +func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { + url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + setKiroHeaders(req, account) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Models []ModelInfo `json:"models"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + return result.Models, nil +} + +func setKiroHeaders(req *http.Request, account *config.Account) { + machineId := account.MachineId + var userAgent, amzUserAgent string + if machineId != "" { + userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", kiroVersion, machineId) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", kiroVersion, machineId) + } else { + userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", kiroVersion) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE-%s", kiroVersion) + } + + req.Header.Set("Authorization", "Bearer "+account.AccessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", userAgent) + req.Header.Set("x-amz-user-agent", amzUserAgent) + req.Header.Set("x-amzn-codewhisperer-optout", "true") +} + +// RefreshAccountInfo 刷新账户信息(使用量、订阅等) +func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { + info := &config.AccountInfo{ + LastRefresh: time.Now().Unix(), + } + + // 获取使用量和订阅信息 + usage, err := GetUsageLimits(account) + if err != nil { + return nil, fmt.Errorf("GetUsageLimits: %w", err) + } + + // 解析用户信息 + if usage.UserInfo != nil { + info.Email = usage.UserInfo.Email + info.UserId = usage.UserInfo.UserId + } + + // 解析订阅信息 + if usage.SubscriptionInfo != nil { + // 优先从 SubscriptionTitle 或 SubscriptionName 解析类型 + titleOrName := usage.SubscriptionInfo.SubscriptionTitle + if titleOrName == "" { + titleOrName = usage.SubscriptionInfo.SubscriptionName + } + if titleOrName == "" { + titleOrName = usage.SubscriptionInfo.SubscriptionType + } + info.SubscriptionType = parseSubscriptionType(titleOrName) + info.SubscriptionTitle = usage.SubscriptionInfo.SubscriptionTitle + if info.SubscriptionTitle == "" { + info.SubscriptionTitle = usage.SubscriptionInfo.SubscriptionName + } + fmt.Printf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s\n", + usage.SubscriptionInfo.SubscriptionType, + usage.SubscriptionInfo.SubscriptionTitle, + usage.SubscriptionInfo.SubscriptionName, + info.SubscriptionType) + } + + // 解析使用量 + if len(usage.UsageBreakdownList) > 0 { + breakdown := usage.UsageBreakdownList[0] + info.UsageCurrent = breakdown.CurrentUsage + info.UsageLimit = breakdown.UsageLimit + if info.UsageLimit > 0 { + info.UsagePercent = info.UsageCurrent / info.UsageLimit + } + } + + // 解析重置日期 + if usage.NextDateReset != "" { + if ts, err := usage.NextDateReset.Int64(); err == nil && ts > 0 { + info.NextResetDate = time.Unix(ts, 0).Format("2006-01-02") + } else if f, err := usage.NextDateReset.Float64(); err == nil && f > 0 { + info.NextResetDate = time.Unix(int64(f), 0).Format("2006-01-02") + } + } + + return info, nil +} + +func parseSubscriptionType(raw string) string { + upper := strings.ToUpper(raw) + if strings.Contains(upper, "PRO_PLUS") || strings.Contains(upper, "PROPLUS") { + return "PRO_PLUS" + } + if strings.Contains(upper, "POWER") { + return "POWER" + } + if strings.Contains(upper, "PRO") { + return "PRO" + } + return "FREE" +} + +// 响应结构体 +type UsageLimitsResponse struct { + UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList"` + NextDateReset json.Number `json:"nextDateReset"` + SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo"` + UserInfo *UserInfo `json:"userInfo"` +} + +type UsageBreakdown struct { + ResourceType string `json:"resourceType"` + CurrentUsage float64 `json:"currentUsage"` + UsageLimit float64 `json:"usageLimit"` + Currency string `json:"currency"` + Unit string `json:"unit"` + OverageRate float64 `json:"overageRate"` + FreeTrialInfo *FreeTrialInfo `json:"freeTrialInfo"` + Bonuses []BonusInfo `json:"bonuses"` +} + +type FreeTrialInfo struct { + CurrentUsage float64 `json:"currentUsage"` + UsageLimit float64 `json:"usageLimit"` + FreeTrialStatus string `json:"freeTrialStatus"` + FreeTrialExpiry int64 `json:"freeTrialExpiry"` +} + +type BonusInfo struct { + BonusCode string `json:"bonusCode"` + DisplayName string `json:"displayName"` + CurrentUsage float64 `json:"currentUsage"` + UsageLimit float64 `json:"usageLimit"` + ExpiresAt int64 `json:"expiresAt"` + Status string `json:"status"` +} + +type SubscriptionInfo struct { + SubscriptionName string `json:"subscriptionName"` + SubscriptionTitle string `json:"subscriptionTitle"` + SubscriptionType string `json:"subscriptionType"` + Status string `json:"status"` + UpgradeCapability string `json:"upgradeCapability"` +} + +type UserInfo struct { + Email string `json:"email"` + UserId string `json:"userId"` +} + +type UserInfoResponse struct { + Email string `json:"email"` + UserId string `json:"userId"` + Idp string `json:"idp"` + Status string `json:"status"` +} + +type ModelInfo struct { + ModelId string `json:"modelId"` + ModelName string `json:"modelName"` + Description string `json:"description"` + InputTypes []string `json:"supportedInputTypes"` + RateMultiplier float64 `json:"rateMultiplier"` + TokenLimits *struct { + MaxInputTokens int `json:"maxInputTokens"` + MaxOutputTokens int `json:"maxOutputTokens"` + } `json:"tokenLimits"` +} diff --git a/proxy/translator.go b/proxy/translator.go new file mode 100644 index 0000000..7849601 --- /dev/null +++ b/proxy/translator.go @@ -0,0 +1,811 @@ +package proxy + +import ( + "encoding/base64" + "encoding/json" + "regexp" + "strings" + "time" + + "github.com/google/uuid" +) + +// 模型映射 +var modelMap = map[string]string{ + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "claude-haiku-4-5": "claude-haiku-4.5", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4.5": "claude-opus-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4-20250514": "claude-sonnet-4", + "claude-3-5-sonnet": "claude-sonnet-4.5", + "claude-3-opus": "claude-sonnet-4.5", + "claude-3-sonnet": "claude-sonnet-4", + "claude-3-haiku": "claude-haiku-4.5", + "gpt-4": "claude-sonnet-4.5", + "gpt-4o": "claude-sonnet-4.5", + "gpt-4-turbo": "claude-sonnet-4.5", + "gpt-3.5-turbo": "claude-sonnet-4.5", +} + +func MapModel(model string) string { + lower := strings.ToLower(model) + for k, v := range modelMap { + if strings.Contains(lower, k) { + return v + } + } + // 如果已经是有效的 Kiro 模型,直接返回 + if strings.HasPrefix(lower, "claude-") { + return model + } + return "claude-sonnet-4.5" +} + +// ==================== Claude API 类型 ==================== + +type ClaudeRequest struct { + Model string `json:"model"` + Messages []ClaudeMessage `json:"messages"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + System interface{} `json:"system,omitempty"` // string or []SystemBlock + Tools []ClaudeTool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` +} + +type ClaudeMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` // string or []ContentBlock +} + +type ClaudeContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input interface{} `json:"input,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content interface{} `json:"content,omitempty"` // for tool_result + Source *ImageSource `json:"source,omitempty"` +} + +type ImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type ClaudeTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema interface{} `json:"input_schema"` +} + +type ClaudeResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []ClaudeContentBlock `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage ClaudeUsage `json:"usage"` +} + +type ClaudeUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// ==================== Claude -> Kiro 转换 ==================== + +const maxToolDescLen = 10237 + +func ClaudeToKiro(req *ClaudeRequest) *KiroPayload { + modelID := MapModel(req.Model) + origin := "AI_EDITOR" + + // 提取系统提示 + systemPrompt := extractSystemPrompt(req.System) + + // 注入时间戳 + timestamp := time.Now().Format(time.RFC3339) + systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt + + // 注入执行导向指令(防止 AI 在探索过程中丢失目标) + executionDirective := ` + +当用户要求执行特定任务时,你必须遵循以下纪律: +1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向 +2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析 +3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态 +4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题 +5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务 +6. **完整交付**:直到所有任务步骤都执行完毕才算完成 + +` + systemPrompt = systemPrompt + "\n\n" + executionDirective + + // 构建历史消息 + history := make([]KiroHistoryMessage, 0) + var currentContent string + var currentImages []KiroImage + var currentToolResults []KiroToolResult + + for i, msg := range req.Messages { + isLast := i == len(req.Messages)-1 + + if msg.Role == "user" { + content, images, toolResults := extractClaudeUserContent(msg.Content) + + if isLast { + currentContent = content + currentImages = images + currentToolResults = toolResults + } else { + userMsg := KiroUserInputMessage{ + Content: content, + ModelID: modelID, + Origin: origin, + } + if len(images) > 0 { + userMsg.Images = images + } + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &UserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, KiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + } else if msg.Role == "assistant" { + content, toolUses := extractClaudeAssistantContent(msg.Content) + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &KiroAssistantResponseMessage{ + Content: content, + ToolUses: toolUses, + }, + }) + } + } + + // 确保 history 以 user 开始 + if len(history) > 0 && history[0].AssistantResponseMessage != nil { + history = append([]KiroHistoryMessage{{ + UserInputMessage: &KiroUserInputMessage{ + Content: "Begin conversation", + ModelID: modelID, + Origin: origin, + }, + }}, history...) + } + + // 构建最终内容 + finalContent := "" + if systemPrompt != "" { + finalContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n\n" + } + if currentContent != "" { + finalContent += currentContent + } else if len(currentToolResults) > 0 { + finalContent += "Tool results provided." + } else { + finalContent += "Continue" + } + + // 转换工具 + kiroTools := convertClaudeTools(req.Tools) + + // 构建 payload + payload := &KiroPayload{} + payload.ConversationState.ChatTriggerType = "MANUAL" + payload.ConversationState.ConversationID = uuid.New().String() + payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ + Content: finalContent, + ModelID: modelID, + Origin: origin, + Images: currentImages, + } + + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext = &UserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + + if len(history) > 0 { + payload.ConversationState.History = history + } + + if req.MaxTokens > 0 || req.Temperature > 0 || req.TopP > 0 { + payload.InferenceConfig = &InferenceConfig{ + MaxTokens: req.MaxTokens, + Temperature: req.Temperature, + TopP: req.TopP, + } + } + + return payload +} + +func extractSystemPrompt(system interface{}) string { + if system == nil { + return "" + } + if s, ok := system.(string); ok { + return s + } + if blocks, ok := system.([]interface{}); ok { + var parts []string + for _, b := range blocks { + if block, ok := b.(map[string]interface{}); ok { + if text, ok := block["text"].(string); ok { + parts = append(parts, text) + } + } + } + return strings.Join(parts, "\n") + } + return "" +} + +func extractClaudeUserContent(content interface{}) (string, []KiroImage, []KiroToolResult) { + var text string + var images []KiroImage + var toolResults []KiroToolResult + + if s, ok := content.(string); ok { + return s, nil, nil + } + + if blocks, ok := content.([]interface{}); ok { + for _, b := range blocks { + block, ok := b.(map[string]interface{}) + if !ok { + continue + } + + blockType, _ := block["type"].(string) + switch blockType { + case "text": + if t, ok := block["text"].(string); ok { + text += t + } + case "image": + if source, ok := block["source"].(map[string]interface{}); ok { + mediaType, _ := source["media_type"].(string) + data, _ := source["data"].(string) + format := strings.TrimPrefix(mediaType, "image/") + if format == "jpg" { + format = "jpeg" + } + images = append(images, KiroImage{ + Format: format, + Source: struct { + Bytes string `json:"bytes"` + }{Bytes: data}, + }) + } + case "tool_result": + toolUseID, _ := block["tool_use_id"].(string) + resultContent := extractToolResultContent(block["content"]) + toolResults = append(toolResults, KiroToolResult{ + ToolUseID: toolUseID, + Content: []KiroResultContent{{Text: resultContent}}, + Status: "success", + }) + } + } + } + + return text, images, toolResults +} + +func extractToolResultContent(content interface{}) string { + if s, ok := content.(string); ok { + return s + } + if blocks, ok := content.([]interface{}); ok { + var parts []string + for _, b := range blocks { + if block, ok := b.(map[string]interface{}); ok { + if text, ok := block["text"].(string); ok { + parts = append(parts, text) + } + } + } + return strings.Join(parts, "") + } + return "" +} + +func extractClaudeAssistantContent(content interface{}) (string, []KiroToolUse) { + var text string + var toolUses []KiroToolUse + + if s, ok := content.(string); ok { + return s, nil + } + + if blocks, ok := content.([]interface{}); ok { + for _, b := range blocks { + block, ok := b.(map[string]interface{}) + if !ok { + continue + } + + blockType, _ := block["type"].(string) + switch blockType { + case "text": + if t, ok := block["text"].(string); ok { + text += t + } + case "tool_use": + id, _ := block["id"].(string) + name, _ := block["name"].(string) + input, _ := block["input"].(map[string]interface{}) + if input == nil { + input = make(map[string]interface{}) + } + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: id, + Name: name, + Input: input, + }) + } + } + } + + if text == "" && len(toolUses) > 0 { + text = "Using tools." + } + + return text, toolUses +} + +func convertClaudeTools(tools []ClaudeTool) []KiroToolWrapper { + if len(tools) == 0 { + return nil + } + + result := make([]KiroToolWrapper, len(tools)) + for i, tool := range tools { + desc := tool.Description + if len(desc) > maxToolDescLen { + desc = desc[:maxToolDescLen] + "..." + } + result[i] = KiroToolWrapper{} + result[i].ToolSpecification.Name = shortenToolName(tool.Name) + result[i].ToolSpecification.Description = desc + result[i].ToolSpecification.InputSchema = InputSchema{JSON: tool.InputSchema} + } + return result +} + +func shortenToolName(name string) string { + if len(name) <= 64 { + return name + } + // MCP tools: mcp__server__tool -> mcp__tool + if strings.HasPrefix(name, "mcp__") { + lastIdx := strings.LastIndex(name, "__") + if lastIdx > 5 { + shortened := "mcp__" + name[lastIdx+2:] + if len(shortened) <= 64 { + return shortened + } + } + } + return name[:64] +} + +// ==================== Kiro -> Claude 转换 ==================== + +func KiroToClaudeResponse(content string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse { + blocks := make([]ClaudeContentBlock, 0) + + if content != "" { + blocks = append(blocks, ClaudeContentBlock{ + Type: "text", + Text: content, + }) + } + + for _, tu := range toolUses { + blocks = append(blocks, ClaudeContentBlock{ + Type: "tool_use", + ID: tu.ToolUseID, + Name: tu.Name, + Input: tu.Input, + }) + } + + stopReason := "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + + return &ClaudeResponse{ + ID: "msg_" + uuid.New().String(), + Type: "message", + Role: "assistant", + Content: blocks, + Model: model, + StopReason: stopReason, + Usage: ClaudeUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + }, + } +} + +// ==================== OpenAI API 类型 ==================== + +type OpenAIRequest struct { + Model string `json:"model"` + Messages []OpenAIMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []OpenAITool `json:"tools,omitempty"` +} + +type OpenAIMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + +type OpenAITool struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters interface{} `json:"parameters"` + } `json:"function"` +} + +type OpenAIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []OpenAIChoice `json:"choices"` + Usage OpenAIUsage `json:"usage"` +} + +type OpenAIChoice struct { + Index int `json:"index"` + Message OpenAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type OpenAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ==================== OpenAI -> Kiro 转换 ==================== + +func OpenAIToKiro(req *OpenAIRequest) *KiroPayload { + modelID := MapModel(req.Model) + origin := "AI_EDITOR" + + // 提取系统提示 + var systemPrompt string + var nonSystemMessages []OpenAIMessage + + for _, msg := range req.Messages { + if msg.Role == "system" { + if s, ok := msg.Content.(string); ok { + systemPrompt += s + "\n" + } + } else { + nonSystemMessages = append(nonSystemMessages, msg) + } + } + + // 注入时间戳 + timestamp := time.Now().Format(time.RFC3339) + systemPrompt = "[Context: Current time is " + timestamp + "]\n\n" + systemPrompt + + // 注入执行导向指令(防止 AI 在探索过程中丢失目标) + executionDirective := ` + +当用户要求执行特定任务时,你必须遵循以下纪律: +1. **目标锁定**:在整个会话中始终牢记用户的原始目标,不要在代码探索过程中迷失方向 +2. **行动优先**:优先执行任务而非仅分析或总结,除非用户明确只要求分析 +3. **计划执行**:为任务创建明确的步骤计划,逐步执行并标记完成状态 +4. **禁止确认性收尾**:在任务未完成前,禁止输出"需要我继续吗?"、"需要深入分析吗?"等确认性问题 +5. **持续推进**:如果发现部分任务已完成,立即继续执行剩余未完成的任务 +6. **完整交付**:直到所有任务步骤都执行完毕才算完成 + +` + systemPrompt = systemPrompt + "\n\n" + executionDirective + + // 构建历史消息 + history := make([]KiroHistoryMessage, 0) + var currentContent string + var currentImages []KiroImage + var currentToolResults []KiroToolResult + systemMerged := false + + for i, msg := range nonSystemMessages { + isLast := i == len(nonSystemMessages)-1 + + switch msg.Role { + case "user": + content, images := extractOpenAIUserContent(msg.Content) + + // 第一条 user 消息合并 system prompt + if !systemMerged && systemPrompt != "" { + content = systemPrompt + "\n" + content + systemMerged = true + } + + if isLast { + currentContent = content + currentImages = images + } else { + history = append(history, KiroHistoryMessage{ + UserInputMessage: &KiroUserInputMessage{ + Content: content, + ModelID: modelID, + Origin: origin, + Images: images, + }, + }) + } + + case "assistant": + content, _ := msg.Content.(string) + if content == "" && len(msg.ToolCalls) > 0 { + content = "Using tools." + } + + var toolUses []KiroToolUse + for _, tc := range msg.ToolCalls { + var input map[string]interface{} + json.Unmarshal([]byte(tc.Function.Arguments), &input) + if input == nil { + input = make(map[string]interface{}) + } + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: tc.ID, + Name: tc.Function.Name, + Input: input, + }) + } + + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &KiroAssistantResponseMessage{ + Content: content, + ToolUses: toolUses, + }, + }) + + case "tool": + content, _ := msg.Content.(string) + currentToolResults = append(currentToolResults, KiroToolResult{ + ToolUseID: msg.ToolCallID, + Content: []KiroResultContent{{Text: content}}, + Status: "success", + }) + + // 检查下一条是否还是 tool + nextIdx := i + 1 + if nextIdx >= len(nonSystemMessages) || nonSystemMessages[nextIdx].Role != "tool" { + if !isLast { + history = append(history, KiroHistoryMessage{ + UserInputMessage: &KiroUserInputMessage{ + Content: "Tool results provided.", + ModelID: modelID, + Origin: origin, + UserInputMessageContext: &UserInputMessageContext{ + ToolResults: currentToolResults, + }, + }, + }) + currentToolResults = nil + } + } + } + } + + // 构建最终内容 + finalContent := currentContent + if finalContent == "" { + if len(currentToolResults) > 0 { + finalContent = "Tool results provided." + } else { + finalContent = "Continue" + } + } + if !systemMerged && systemPrompt != "" { + finalContent = systemPrompt + "\n" + finalContent + } + + // 转换工具 + kiroTools := convertOpenAITools(req.Tools) + + // 构建 payload + payload := &KiroPayload{} + payload.ConversationState.ChatTriggerType = "MANUAL" + payload.ConversationState.ConversationID = uuid.New().String() + payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ + Content: finalContent, + ModelID: modelID, + Origin: origin, + Images: currentImages, + } + + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext = &UserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + + if len(history) > 0 { + payload.ConversationState.History = history + } + + if req.MaxTokens > 0 || req.Temperature > 0 || req.TopP > 0 { + payload.InferenceConfig = &InferenceConfig{ + MaxTokens: req.MaxTokens, + Temperature: req.Temperature, + TopP: req.TopP, + } + } + + return payload +} + +func extractOpenAIUserContent(content interface{}) (string, []KiroImage) { + if s, ok := content.(string); ok { + return s, nil + } + + var text string + var images []KiroImage + + if parts, ok := content.([]interface{}); ok { + for _, p := range parts { + part, ok := p.(map[string]interface{}) + if !ok { + continue + } + + partType, _ := part["type"].(string) + switch partType { + case "text": + if t, ok := part["text"].(string); ok { + text += t + } + case "image_url": + if imgUrl, ok := part["image_url"].(map[string]interface{}); ok { + if url, ok := imgUrl["url"].(string); ok { + if img := parseDataURL(url); img != nil { + images = append(images, *img) + } + } + } + } + } + } + + return text, images +} + +func parseDataURL(url string) *KiroImage { + // data:image/png;base64,xxxxx + re := regexp.MustCompile(`^data:image/(\w+);base64,(.+)$`) + matches := re.FindStringSubmatch(url) + if len(matches) != 3 { + return nil + } + + format := matches[1] + if format == "jpg" { + format = "jpeg" + } + + // 验证 base64 + if _, err := base64.StdEncoding.DecodeString(matches[2]); err != nil { + return nil + } + + return &KiroImage{ + Format: format, + Source: struct { + Bytes string `json:"bytes"` + }{Bytes: matches[2]}, + } +} + +func convertOpenAITools(tools []OpenAITool) []KiroToolWrapper { + if len(tools) == 0 { + return nil + } + + result := make([]KiroToolWrapper, 0, len(tools)) + for _, tool := range tools { + if tool.Type != "function" { + continue + } + desc := tool.Function.Description + if len(desc) > maxToolDescLen { + desc = desc[:maxToolDescLen] + "..." + } + wrapper := KiroToolWrapper{} + wrapper.ToolSpecification.Name = shortenToolName(tool.Function.Name) + wrapper.ToolSpecification.Description = desc + wrapper.ToolSpecification.InputSchema = InputSchema{JSON: tool.Function.Parameters} + result = append(result, wrapper) + } + return result +} + +// ==================== Kiro -> OpenAI 转换 ==================== + +func KiroToOpenAIResponse(content string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *OpenAIResponse { + msg := OpenAIMessage{ + Role: "assistant", + } + + finishReason := "stop" + + if len(toolUses) > 0 { + msg.Content = nil + msg.ToolCalls = make([]ToolCall, len(toolUses)) + for i, tu := range toolUses { + args, _ := json.Marshal(tu.Input) + msg.ToolCalls[i] = ToolCall{ + ID: tu.ToolUseID, + Type: "function", + } + msg.ToolCalls[i].Function.Name = tu.Name + msg.ToolCalls[i].Function.Arguments = string(args) + } + finishReason = "tool_calls" + } else { + msg.Content = content + } + + return &OpenAIResponse{ + ID: "chatcmpl-" + uuid.New().String(), + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + Choices: []OpenAIChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }}, + Usage: OpenAIUsage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + }, + } +} diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..67059d1 --- /dev/null +++ b/web/index.html @@ -0,0 +1,616 @@ + + + + + + Kiro API Proxy + + + +
+ +
+ + + + + + + + + +