diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..a43909b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +.git +.github +Kiro-Go +data +backup* +PR-*.md +docker-compose.yml +README.md +README_CN.md +*.log \ No newline at end of file diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b1ad3d1..125fecc 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -2,10 +2,10 @@ name: Build Docker Image on: push: - branches: [main, master] + branches: [main, master, dev] tags: ['v*'] pull_request: - branches: [main, master] + branches: [main, master, dev] workflow_dispatch: env: @@ -23,8 +23,14 @@ jobs: - name: Checkout uses: actions/checkout@v4 + - name: Set lowercase image name + id: image + run: echo "name=$(echo '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}' | tr '[:upper:]' '[:lower:]')" >> "$GITHUB_OUTPUT" + - name: Set up QEMU uses: docker/setup-qemu-action@v3 + with: + platforms: arm64 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -41,7 +47,7 @@ jobs: id: meta uses: docker/metadata-action@v5 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ steps.image.outputs.name }} tags: | type=raw,value=latest,enable={{is_default_branch}} type=ref,event=branch @@ -50,7 +56,7 @@ jobs: type=sha,prefix= - name: Build and push - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . platforms: linux/amd64,linux/arm64 @@ -59,3 +65,5 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max + provenance: false + diff --git a/Dockerfile b/Dockerfile index 9834d80..7c6cfa4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,27 @@ -FROM golang:1.21-alpine AS builder +# builder 阶段始终运行在构建机原生平台(amd64),用 Go 交叉编译目标平台二进制 +FROM --platform=$BUILDPLATFORM golang:1.21-alpine AS builder + +ARG TARGETOS +ARG TARGETARCH WORKDIR /app COPY go.mod go.sum ./ -RUN go mod download +RUN --mount=type=cache,target=/go/pkg/mod \ + go mod download COPY . . -RUN CGO_ENABLED=0 GOOS=linux go build -o kiro-api-proxy . +RUN --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o kiro-go . FROM alpine:latest RUN apk --no-cache add ca-certificates WORKDIR /app -COPY --from=builder /app/kiro-api-proxy . +COPY --from=builder /app/kiro-go . COPY --from=builder /app/web ./web EXPOSE 8080 VOLUME /app/data -CMD ["./kiro-api-proxy"] +CMD ["./kiro-go"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1bf685b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Quorinex + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index b5d9b85..17a8649 100644 --- a/README.md +++ b/README.md @@ -8,20 +8,16 @@ Convert Kiro accounts to OpenAI / Anthropic compatible API service. [English](README.md) | [中文](README_CN.md) +If this project helps you, a Star would mean a lot. + ## Features -- 🔄 **Anthropic Claude API** - Full support for `/v1/messages` endpoint -- 🤖 **OpenAI Chat API** - Compatible with `/v1/chat/completions` -- ⚖️ **Multi-Account Pool** - Round-robin load balancing -- 🔐 **Auto Token Refresh** - Seamless token management -- 📡 **Streaming** - Real-time SSE responses -- 🎛️ **Web Admin Panel** - Easy account management -- 🔑 **Multiple Auth Methods** - AWS Builder ID, IAM Identity Center (Enterprise SSO), SSO Token, Local Cache, Credentials -- 📊 **Usage Tracking** - Monitor requests, tokens, and credits -- 📦 **Account Export/Import** - Compatible with Kiro Account Manager format -- 🔄 **Dynamic Model List** - Auto-synced from Kiro API with caching -- 🔔 **Version Update Check** - Automatic new version notification -- 🌐 **i18n** - Chinese / English admin panel +- Anthropic `/v1/messages` & OpenAI `/v1/chat/completions` +- Multi-account pool with round-robin load balancing +- Auto token refresh, SSE streaming, Web admin panel +- Multiple auth: AWS Builder ID, IAM Identity Center (Enterprise SSO), SSO Token, local cache, credentials JSON +- Usage tracking, account import/export, i18n (CN / EN) +- Support configuring outbound proxy (SOCKS5 / HTTP) ## Quick Start @@ -30,19 +26,13 @@ Convert Kiro accounts to OpenAI / Anthropic compatible API service. ```bash git clone https://github.com/Quorinex/Kiro-Go.git cd Kiro-Go - -# Create data directory for persistence mkdir -p data - docker-compose up -d ``` ### Docker Run ```bash -# Create data directory -mkdir -p /path/to/data - docker run -d \ --name kiro-go \ -p 8080:8080 \ @@ -52,8 +42,6 @@ docker run -d \ ghcr.io/quorinex/kiro-go:latest ``` -> 📁 The `/app/data` volume stores `config.json` with accounts and settings. Mount it for data persistence. - ### Build from Source ```bash @@ -63,22 +51,35 @@ go build -o kiro-go . ./kiro-go ``` -## Configuration +Config is auto-created at `data/config.json`. Mount `/app/data` for persistence. The default admin password is `changeme` — override it via the `ADMIN_PASSWORD` env var or change it in the admin panel before going to production. -Config file is auto-created at `data/config.json` on first run: +## Usage -```json -{ - "password": "changeme", - "port": 8080, - "host": "127.0.0.1", - "requireApiKey": false, - "apiKey": "", - "accounts": [] -} +Open `http://localhost:8080/admin`, log in, add accounts, then call the API: + +```bash +# Claude +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "anthropic-version: 2023-06-01" \ + -d '{"model":"claude-sonnet-4.5","max_tokens":1024,"messages":[{"role":"user","content":"Hello!"}]}' + +# OpenAI +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any" \ + -d '{"model":"gpt-4o","messages":[{"role":"user","content":"Hello!"}]}' ``` -> ⚠️ **Change the default password before production use!** +## Thinking Mode + +Append a suffix (default `-thinking`) to the model name, e.g. `claude-sonnet-4.5-thinking`. Claude-compatible requests that include a top-level `thinking` config such as `{"type":"enabled","budget_tokens":2048}` or `{"type":"adaptive"}` also enable thinking mode automatically. Configure output format in the admin panel under Settings - Thinking Mode. + +## Outbound Proxy + +For users in restricted network regions, configure an outbound proxy in the admin panel under **Settings - Outbound Proxy Settings**. Supports SOCKS5 and HTTP proxies. + +The setting takes effect immediately without restarting. ## Environment Variables @@ -87,168 +88,17 @@ Config file is auto-created at `data/config.json` on first run: | `CONFIG_PATH` | Config file path | `data/config.json` | | `ADMIN_PASSWORD` | Admin panel password (overrides config) | - | -## Usage +## Contributing -### 1. Access Admin Panel +Friendly discussion is welcome. If you run into issues, try asking Claude Code, Codex, or similar tools for help first — most problems can be solved that way. PRs are even better. -Open `http://localhost:8080/admin` and login with your password. +## Friend Links -### 2. Add Accounts - -Multiple methods available: - -| Method | Description | -|--------|-------------| -| **AWS Builder ID** | Login with AWS Builder ID (personal accounts) | -| **IAM Identity Center (Enterprise SSO)** | Login with IAM Identity Center (enterprise accounts) | -| **SSO Token** | Import `x-amz-sso_authn` token from browser | -| **Kiro Local Cache** | Import from local Kiro IDE cache files | -| **Credentials JSON** | Import JSON from Kiro Account Manager | - -#### Credentials Format - -```json -{ - "refreshToken": "eyJ...", - "accessToken": "eyJ...", - "clientId": "xxx", - "clientSecret": "xxx" -} -``` - -### 3. Call API - -#### Claude API - -```bash -curl http://localhost:8080/v1/messages \ - -H "Content-Type: application/json" \ - -H "anthropic-version: 2023-06-01" \ - -d '{ - "model": "claude-sonnet-4-20250514", - "max_tokens": 1024, - "messages": [{"role": "user", "content": "Hello!"}] - }' -``` - -#### OpenAI API - -```bash -curl http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer any" \ - -d '{ - "model": "gpt-4o", - "messages": [{"role": "user", "content": "Hello!"}] - }' -``` - -## Model Mapping - -| Request Model | Actual Model | -|---------------|--------------| -| `claude-sonnet-4-20250514` | claude-sonnet-4-20250514 | -| `claude-sonnet-4.5` | claude-sonnet-4.5 | -| `claude-haiku-4.5` | claude-haiku-4.5 | -| `claude-opus-4.5` | claude-opus-4.5 | -| `claude-opus-4.6` | claude-opus-4.6 | -| `gpt-4o`, `gpt-4` | claude-sonnet-4-20250514 | -| `gpt-3.5-turbo` | claude-sonnet-4-20250514 | - -## Thinking Mode - -Enable extended thinking by adding a suffix to the model name (default: `-thinking`). - -### Usage - -```bash -# OpenAI API with thinking -curl http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "claude-sonnet-4.5-thinking", - "messages": [{"role": "user", "content": "Solve this step by step: 15 * 23"}], - "stream": true - }' - -# Claude API with thinking -curl http://localhost:8080/v1/messages \ - -H "Content-Type: application/json" \ - -H "anthropic-version: 2023-06-01" \ - -d '{ - "model": "claude-sonnet-4.5-thinking", - "max_tokens": 4096, - "messages": [{"role": "user", "content": "Analyze this problem"}] - }' -``` - -### Configuration - -Configure thinking mode in the Admin Panel under **Settings > Thinking Mode Settings**: - -| Setting | Description | Options | -|---------|-------------|---------| -| **Trigger Suffix** | Model name suffix to enable thinking | Default: `-thinking` (customizable, e.g., `-think`, `-reason`) | -| **OpenAI Output Format** | How thinking content is returned in OpenAI API | `reasoning_content` (DeepSeek compatible), `` tag, `` tag | -| **Claude Output Format** | How thinking content is returned in Claude API | `` tag (default), `` tag, plain text | - -### Output Formats - -**OpenAI API (`/v1/chat/completions`)**: -- `reasoning_content` - Thinking in separate `reasoning_content` field (DeepSeek compatible) -- `thinking` - Thinking wrapped in `...` tags in content -- `think` - Thinking wrapped in `...` tags in content - -**Claude API (`/v1/messages`)**: -- `thinking` - Thinking wrapped in `...` tags (default) -- `think` - Thinking wrapped in `...` tags -- `reasoning_content` - Plain text output - -## API Endpoints - -| Endpoint | Description | -|----------|-------------| -| `GET /health` | Health check | -| `GET /v1/models` | List models | -| `GET /v1/stats` | Statistics | -| `POST /v1/messages` | Claude Messages API | -| `POST /v1/messages/count_tokens` | Token counting | -| `POST /v1/chat/completions` | OpenAI Chat API | -| `GET /admin` | Admin panel | - -## Project Structure - -``` -Kiro-Go/ -├── main.go # Entry point -├── version.json # Version info for update check -├── config/ # Configuration management -├── pool/ # Account pool & load balancing -├── proxy/ # API handlers & Kiro client -│ ├── handler.go # HTTP routing & admin API -│ ├── kiro.go # Kiro API client -│ ├── kiro_api.go # Kiro REST API (usage, models) -│ └── translator.go # Request/response conversion -├── auth/ # Authentication -│ ├── builderid.go # AWS Builder ID login -│ ├── iam_sso.go # IAM SSO login -│ ├── oidc.go # OIDC token refresh -│ └── sso_token.go # SSO token import -├── web/ # Admin panel frontend -├── Dockerfile -└── docker-compose.yml -``` +- [LINUX DO](https://linux.do) ## Disclaimer -This project is provided for **educational and research purposes only**. - -- This software is not affiliated with, endorsed by, or associated with Amazon, AWS, or Kiro in any way -- Users are solely responsible for ensuring their use complies with all applicable terms of service and laws -- The authors assume no liability for any misuse or violations arising from the use of this software -- Use at your own risk - -By using this software, you acknowledge that you have read and understood this disclaimer. +For educational and research purposes only. Not affiliated with Amazon, AWS, or Kiro. Users are responsible for complying with applicable terms of service and laws. Use at your own risk. ## License diff --git a/README_CN.md b/README_CN.md index 750884b..8e9fdf6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -8,20 +8,16 @@ [English](README.md) | 中文 +如果这个项目帮到了你,欢迎点个 Star 支持一下。 + ## 功能特性 -- 🔄 **Anthropic Claude API** - 完整支持 `/v1/messages` 端点 -- 🤖 **OpenAI Chat API** - 兼容 `/v1/chat/completions` -- ⚖️ **多账号池** - 轮询负载均衡 -- 🔐 **自动刷新 Token** - 无缝 Token 管理 -- 📡 **流式响应** - 实时 SSE 输出 -- 🎛️ **Web 管理面板** - 便捷的账号管理 -- 🔑 **多种认证方式** - AWS Builder ID、IAM Identity Center (企业 SSO)、SSO Token、本地缓存、凭证 JSON -- 📊 **用量追踪** - 监控请求数、Token、Credits -- 📦 **账号导入导出** - 兼容 Kiro Account Manager 格式 -- 🔄 **动态模型列表** - 自动从 Kiro API 同步并缓存 -- 🔔 **版本更新检测** - 自动提醒新版本 -- 🌐 **中英双语** - 管理面板支持中文 / 英文 +- Anthropic `/v1/messages` 与 OpenAI `/v1/chat/completions` +- 多账号池轮询负载均衡 +- 自动 Token 刷新、SSE 流式输出、Web 管理面板 +- 多种认证方式:AWS Builder ID、IAM Identity Center (企业 SSO)、SSO Token、本地缓存、凭证 JSON +- 用量追踪、账号导入导出、中英双语 +- 支持设置出站代理(SOCKS5 / HTTP) ## 快速开始 @@ -30,19 +26,13 @@ ```bash git clone https://github.com/Quorinex/Kiro-Go.git cd Kiro-Go - -# 创建数据目录用于持久化 mkdir -p data - docker-compose up -d ``` ### Docker 运行 ```bash -# 创建数据目录 -mkdir -p /path/to/data - docker run -d \ --name kiro-go \ -p 8080:8080 \ @@ -52,8 +42,6 @@ docker run -d \ ghcr.io/quorinex/kiro-go:latest ``` -> 📁 `/app/data` 卷存储 `config.json`(包含账号和设置),挂载此目录以实现数据持久化。 - ### 源码编译 ```bash @@ -63,22 +51,35 @@ go build -o kiro-go . ./kiro-go ``` -## 配置 +首次运行会在 `data/config.json` 自动生成配置,挂载 `/app/data` 以持久化。默认管理密码为 `changeme`,生产环境请务必通过 `ADMIN_PASSWORD` 环境变量或在管理面板中修改。 -首次运行会自动创建 `data/config.json`: +## 使用方法 -```json -{ - "password": "changeme", - "port": 8080, - "host": "127.0.0.1", - "requireApiKey": false, - "apiKey": "", - "accounts": [] -} +访问 `http://localhost:8080/admin` 登录、添加账号,然后调用 API: + +```bash +# Claude +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "anthropic-version: 2023-06-01" \ + -d '{"model":"claude-sonnet-4.5","max_tokens":1024,"messages":[{"role":"user","content":"你好!"}]}' + +# OpenAI +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer any" \ + -d '{"model":"gpt-4o","messages":[{"role":"user","content":"你好!"}]}' ``` -> ⚠️ **生产环境请务必修改默认密码!** +## 思考模式 + +在模型名后加后缀(默认 `-thinking`)即可启用,例如 `claude-sonnet-4.5-thinking`。Claude 兼容请求如果带有顶层 `thinking` 配置,例如 `{"type":"enabled","budget_tokens":2048}` 或 `{"type":"adaptive"}`,也会自动启用 thinking 模式。输出格式可在管理面板「设置 - Thinking 模式」中配置。 + +## 出站代理 + +可在管理面板「设置 - 出站代理设置」中配置代理。支持 SOCKS5 和 HTTP 代理。 + +设置保存后即时生效,无需重启服务。 ## 环境变量 @@ -87,168 +88,17 @@ go build -o kiro-go . | `CONFIG_PATH` | 配置文件路径 | `data/config.json` | | `ADMIN_PASSWORD` | 管理面板密码(覆盖配置文件) | - | -## 使用方法 +## 参与贡献 -### 1. 访问管理面板 +欢迎友好交流。遇到问题时,建议先让 Claude Code、Codex 等工具帮忙排查一下,大部分问题都能自己解决。如果能直接提个 PR 就更好了。 -打开 `http://localhost:8080/admin`,输入密码登录。 +## 友情链接 -### 2. 添加账号 - -支持多种方式: - -| 方式 | 说明 | -|------|------| -| **AWS Builder ID** | 通过 AWS Builder ID 授权登录(个人账号) | -| **IAM Identity Center (企业 SSO) 登录** | 通过 IAM Identity Center (企业 SSO) 授权登录(企业账号) | -| **SSO Token** | 通过浏览器 `x-amz-sso_authn` Token 添加账号 | -| **Kiro 本地缓存** | 通过 Kiro IDE 本地缓存文件添加账号 | -| **凭证 JSON** | 通过 Kiro Account Manager 导出的凭证添加账号 | - -#### 凭证格式 - -```json -{ - "refreshToken": "eyJ...", - "accessToken": "eyJ...", - "clientId": "xxx", - "clientSecret": "xxx" -} -``` - -### 3. 调用 API - -#### Claude API - -```bash -curl http://localhost:8080/v1/messages \ - -H "Content-Type: application/json" \ - -H "anthropic-version: 2023-06-01" \ - -d '{ - "model": "claude-sonnet-4-20250514", - "max_tokens": 1024, - "messages": [{"role": "user", "content": "你好!"}] - }' -``` - -#### OpenAI API - -```bash -curl http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer any" \ - -d '{ - "model": "gpt-4o", - "messages": [{"role": "user", "content": "你好!"}] - }' -``` - -## 模型映射 - -| 请求模型 | 实际模型 | -|---------|---------| -| `claude-sonnet-4-20250514` | claude-sonnet-4-20250514 | -| `claude-sonnet-4.5` | claude-sonnet-4.5 | -| `claude-haiku-4.5` | claude-haiku-4.5 | -| `claude-opus-4.5` | claude-opus-4.5 | -| `claude-opus-4.6` | claude-opus-4.6 | -| `gpt-4o`, `gpt-4` | claude-sonnet-4-20250514 | -| `gpt-3.5-turbo` | claude-sonnet-4-20250514 | - -## 思考模式 - -在模型名称后添加后缀(默认:`-thinking`)即可启用扩展思考模式。 - -### 使用方法 - -```bash -# OpenAI API 启用思考 -curl http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "claude-sonnet-4.5-thinking", - "messages": [{"role": "user", "content": "一步步解决:15 * 23"}], - "stream": true - }' - -# Claude API 启用思考 -curl http://localhost:8080/v1/messages \ - -H "Content-Type: application/json" \ - -H "anthropic-version: 2023-06-01" \ - -d '{ - "model": "claude-sonnet-4.5-thinking", - "max_tokens": 4096, - "messages": [{"role": "user", "content": "分析这个问题"}] - }' -``` - -### 配置 - -在管理面板的 **设置 > Thinking 模式设置** 中配置: - -| 设置 | 说明 | 选项 | -|-----|------|------| -| **触发后缀** | 启用思考的模型名称后缀 | 默认:`-thinking`(可自定义,如 `-think`、`-sikao`) | -| **OpenAI 输出格式** | OpenAI API 中思考内容的返回方式 | `reasoning_content`(DeepSeek 兼容)、`` 标签、`` 标签 | -| **Claude 输出格式** | Claude API 中思考内容的返回方式 | `` 标签(默认)、`` 标签、纯文本 | - -### 输出格式说明 - -**OpenAI API (`/v1/chat/completions`)**: -- `reasoning_content` - 思考内容放在单独的 `reasoning_content` 字段(DeepSeek 兼容) -- `thinking` - 思考内容用 `...` 标签包裹在 content 中 -- `think` - 思考内容用 `...` 标签包裹在 content 中 - -**Claude API (`/v1/messages`)**: -- `thinking` - 思考内容用 `...` 标签包裹(默认) -- `think` - 思考内容用 `...` 标签包裹 -- `reasoning_content` - 纯文本输出 - -## API 端点 - -| 端点 | 说明 | -|-----|------| -| `GET /health` | 健康检查 | -| `GET /v1/models` | 模型列表 | -| `GET /v1/stats` | 统计数据 | -| `POST /v1/messages` | Claude Messages API | -| `POST /v1/messages/count_tokens` | Token 计数 | -| `POST /v1/chat/completions` | OpenAI Chat API | -| `GET /admin` | 管理面板 | - -## 项目结构 - -``` -Kiro-Go/ -├── main.go # 入口 -├── version.json # 版本信息(用于更新检测) -├── config/ # 配置管理 -├── pool/ # 账号池 & 负载均衡 -├── proxy/ # API 处理 & Kiro 客户端 -│ ├── handler.go # HTTP 路由 & 管理 API -│ ├── kiro.go # Kiro API 客户端 -│ ├── kiro_api.go # Kiro REST API(用量、模型) -│ └── translator.go # 请求/响应转换 -├── auth/ # 认证 -│ ├── builderid.go # AWS Builder ID 登录 -│ ├── iam_sso.go # IAM SSO 登录 -│ ├── oidc.go # OIDC Token 刷新 -│ └── sso_token.go # SSO Token 导入 -├── web/ # 管理面板前端 -├── Dockerfile -└── docker-compose.yml -``` +- [LINUX DO](https://linux.do) ## 免责声明 -本项目仅供**学习和研究目的**使用。 - -- 本软件与 Amazon、AWS 或 Kiro 没有任何关联、认可或合作关系 -- 用户需自行确保其使用行为符合所有适用的服务条款和法律法规 -- 作者不对因使用本软件而产生的任何滥用或违规行为承担责任 -- 使用风险自负 - -使用本软件即表示您已阅读并理解本免责声明。 +本项目仅供学习和研究目的使用,与 Amazon、AWS 或 Kiro 没有任何关联。用户需自行确保使用行为符合所有适用的服务条款和法律法规,使用风险自负。 ## 许可证 diff --git a/auth/builderid.go b/auth/builderid.go index 460ad6b..21a74d9 100644 --- a/auth/builderid.go +++ b/auth/builderid.go @@ -57,7 +57,7 @@ func StartBuilderIdLogin(region string) (*BuilderIdSession, error) { regReq, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(regBody)) regReq.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() regResp, err := client.Do(regReq) if err != nil { return nil, fmt.Errorf("register client failed: %v", err) @@ -175,7 +175,7 @@ func PollBuilderIdAuth(sessionID string) (accessToken, refreshToken, clientID, c tokenReq, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(tokenBody)) tokenReq.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() tokenResp, err := client.Do(tokenReq) if err != nil { return "", "", "", "", "", 0, "", fmt.Errorf("token request failed: %v", err) diff --git a/auth/http_client.go b/auth/http_client.go index 836fb7c..4604d70 100644 --- a/auth/http_client.go +++ b/auth/http_client.go @@ -3,18 +3,48 @@ package auth import ( "net/http" + "net/url" + "sync/atomic" "time" ) -// 全局 HTTP 客户端,复用连接池 -// 用于所有 auth 模块的 HTTP 请求 -var httpClient = &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 50, // 最大空闲连接数 - MaxIdleConnsPerHost: 10, // 每个 Host 最大空闲连接数 - IdleConnTimeout: 90 * time.Second, // 空闲连接超时 - DisableCompression: false, // 启用压缩 - ForceAttemptHTTP2: true, // 尝试使用 HTTP/2 - }, +// 全局 HTTP 客户端存储,支持运行时代理重配置 +var httpClientStore atomic.Pointer[http.Client] + +// httpClient 返回当前全局 auth HTTP 客户端 +func httpClient() *http.Client { + return httpClientStore.Load() +} + +func init() { + InitHttpClient("") +} + +// buildAuthTransport 构建带可选代理的 Transport +func buildAuthTransport(proxyURL string) *http.Transport { + t := &http.Transport{ + MaxIdleConns: 50, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + if proxyURL != "" { + if u, err := url.Parse(proxyURL); err == nil { + t.Proxy = http.ProxyURL(u) + t.ForceAttemptHTTP2 = false + } + } else { + t.Proxy = http.ProxyFromEnvironment + } + return t +} + +// InitHttpClient 初始化(或重新初始化)auth 模块的全局 HTTP 客户端 +func InitHttpClient(proxyURL string) { + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: buildAuthTransport(proxyURL), + } + httpClientStore.Store(client) } diff --git a/auth/http_client_test.go b/auth/http_client_test.go new file mode 100644 index 0000000..3f5d505 --- /dev/null +++ b/auth/http_client_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "net/http" + "net/url" + "testing" +) + +func TestBuildAuthTransportUsesExplicitProxyURL(t *testing.T) { + transport := buildAuthTransport("http://proxy.local:8080") + req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://proxy.local:8080") +} + +func TestBuildAuthTransportFallsBackToEnvironmentProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + transport := buildAuthTransport("") + req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://env-proxy.local:2323") +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + if err != nil { + t.Fatalf("invalid test URL: %v", err) + } + return parsed +} + +func assertProxyURL(t *testing.T, got *url.URL, want string) { + t.Helper() + if got == nil { + t.Fatalf("expected proxy URL %q, got nil", want) + } + if got.String() != want { + t.Fatalf("expected proxy URL %q, got %q", want, got.String()) + } +} diff --git a/auth/iam_sso.go b/auth/iam_sso.go index e17e4eb..bfd4a4a 100644 --- a/auth/iam_sso.go +++ b/auth/iam_sso.go @@ -170,7 +170,7 @@ func registerOIDCClient(oidcBase, startUrl, redirectUri string) (clientID, clien req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := httpClient().Do(req) if err != nil { return "", "", err } @@ -207,7 +207,7 @@ func exchangeToken(oidcBase, clientID, clientSecret, code, codeVerifier, redirec req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := httpClient().Do(req) if err != nil { return "", "", 0, err } diff --git a/auth/oidc.go b/auth/oidc.go index 40d3456..7dcb494 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "io" - "kiro-api-proxy/config" + "kiro-go/config" "net/http" "time" ) @@ -40,7 +40,7 @@ func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (stri req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := httpClient().Do(req) if err != nil { return "", "", 0, err } @@ -77,7 +77,7 @@ func refreshSocialToken(refreshToken string) (string, string, int64, error) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := httpClient().Do(req) if err != nil { return "", "", 0, err } diff --git a/auth/sso_token.go b/auth/sso_token.go index 22da746..dee0540 100644 --- a/auth/sso_token.go +++ b/auth/sso_token.go @@ -79,7 +79,7 @@ func registerDeviceClient(oidcBase, startUrl string) (clientID, clientSecret str req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return "", "", err @@ -110,7 +110,7 @@ func startDeviceAuth(oidcBase, clientID, clientSecret, startUrl string) (deviceC req, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return "", "", 0, err @@ -139,7 +139,7 @@ func verifyBearerToken(portalBase, bearerToken string) error { req.Header.Set("Authorization", "Bearer "+bearerToken) req.Header.Set("Accept", "application/json") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return err @@ -157,7 +157,7 @@ func getDeviceSessionToken(portalBase, bearerToken string) (string, error) { req.Header.Set("Authorization", "Bearer "+bearerToken) req.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return "", err @@ -193,7 +193,7 @@ func acceptUserCode(oidcBase, userCode, deviceSessionToken string) (*deviceConte req.Header.Set("Content-Type", "application/json") req.Header.Set("Referer", "https://view.awsapps.com/") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return nil, err @@ -227,7 +227,7 @@ func approveAuth(oidcBase string, deviceContext *deviceContextInfo, deviceSessio req.Header.Set("Content-Type", "application/json") req.Header.Set("Referer", "https://view.awsapps.com/") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return err @@ -262,7 +262,7 @@ func pollForToken(oidcBase, clientID, clientSecret, deviceCode string, interval req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { continue @@ -311,7 +311,7 @@ func GetUserInfo(accessToken string) (email, userID string, err error) { req.Header.Set("User-Agent", "aws-sdk-js/1.0.18 KiroAPIProxy") req.Header.Set("x-amz-user-agent", "aws-sdk-js/1.0.18 KiroAPIProxy") - client := httpClient + client := httpClient() resp, err := client.Do(req) if err != nil { return "", "", err diff --git a/config/config.go b/config/config.go index 47a14e6..9b3dae7 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,7 @@ import ( "encoding/json" "fmt" "os" + "runtime" "sync" ) @@ -49,6 +50,7 @@ type Account struct { StartUrl string `json:"startUrl,omitempty"` // AWS SSO start URL ExpiresAt int64 `json:"expiresAt,omitempty"` // Token expiration timestamp (Unix seconds) MachineId string `json:"machineId,omitempty"` // UUID machine identifier for request tracking + ProfileArn string `json:"profileArn,omitempty"` // CodeWhisperer/Kiro profile ARN for generation requests // Priority weight for load balancing (higher = more requests) Weight int `json:"weight,omitempty"` // 0 or 1 = normal, 2+ = higher priority @@ -98,7 +100,10 @@ type Config struct { Host string `json:"host"` // HTTP server bind address (default: 0.0.0.0) ApiKey string `json:"apiKey,omitempty"` // API key for client authentication RequireApiKey bool `json:"requireApiKey"` // Whether to enforce API key validation - Accounts []Account `json:"accounts"` // Registered Kiro accounts + KiroVersion string `json:"kiroVersion,omitempty"` + SystemVersion string `json:"systemVersion,omitempty"` + NodeVersion string `json:"nodeVersion,omitempty"` + Accounts []Account `json:"accounts"` // Registered Kiro accounts // Thinking mode configuration for extended reasoning output ThinkingSuffix string `json:"thinkingSuffix,omitempty"` // Model suffix to trigger thinking mode (default: "-thinking") @@ -108,6 +113,12 @@ type Config struct { // Endpoint configuration: "auto", "codewhisperer", or "amazonq" PreferredEndpoint string `json:"preferredEndpoint,omitempty"` + // Proxy configuration: optional outbound proxy for Kiro API requests + // Format: "socks5://host:port", "socks5://user:pass@host:port", + // "http://host:port", "http://user:pass@host:port" + // Leave empty to connect directly. + ProxyURL string `json:"proxyURL,omitempty"` + // Global statistics (persisted across restarts) TotalRequests int `json:"totalRequests,omitempty"` // Total API requests received SuccessRequests int `json:"successRequests,omitempty"` // Successful requests count @@ -136,8 +147,8 @@ type AccountInfo struct { TrialExpiresAt int64 } -// Version 当前版本号 -const Version = "1.0.3" +// Version current version +const Version = "1.0.6" var ( cfg *Config @@ -268,6 +279,18 @@ func UpdateAccount(id string, account Account) error { return nil } +func UpdateAccountProfileArn(id, profileArn string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts[i].ProfileArn = profileArn + return Save() + } + } + return nil +} + func DeleteAccount(id string) error { cfgLock.Lock() defer cfgLock.Unlock() @@ -444,3 +467,64 @@ func UpdatePreferredEndpoint(endpoint string) error { cfg.PreferredEndpoint = endpoint return Save() } + +// GetProxyURL 获取出站代理地址 +func GetProxyURL() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + return cfg.ProxyURL +} + +// UpdateProxySettings 更新出站代理配置 +func UpdateProxySettings(proxyURL string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.ProxyURL = proxyURL + return Save() +} + +type KiroClientConfig struct { + KiroVersion string + SystemVersion string + NodeVersion string +} + +func GetKiroClientConfig() KiroClientConfig { + cfgLock.RLock() + defer cfgLock.RUnlock() + + kiroVersion := "0.11.107" + if cfg != nil && cfg.KiroVersion != "" { + kiroVersion = cfg.KiroVersion + } + + systemVersion := "" + if cfg != nil { + systemVersion = cfg.SystemVersion + } + if systemVersion == "" { + systemVersion = defaultSystemVersion() + } + + nodeVersion := "22.22.0" + if cfg != nil && cfg.NodeVersion != "" { + nodeVersion = cfg.NodeVersion + } + + return KiroClientConfig{ + KiroVersion: kiroVersion, + SystemVersion: systemVersion, + NodeVersion: nodeVersion, + } +} + +func defaultSystemVersion() string { + switch runtime.GOOS { + case "windows": + return "win32#10.0.22631" + case "darwin": + return "darwin#24.6.0" + default: + return "linux#6.6.87" + } +} diff --git a/go.mod b/go.mod index f1bd668..4be296b 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module kiro-api-proxy +module kiro-go go 1.21 diff --git a/main.go b/main.go index 63631ef..99de1c3 100644 --- a/main.go +++ b/main.go @@ -15,9 +15,9 @@ package main import ( "fmt" - "kiro-api-proxy/config" - "kiro-api-proxy/pool" - "kiro-api-proxy/proxy" + "kiro-go/config" + "kiro-go/pool" + "kiro-go/proxy" "log" "net/http" "os" diff --git a/pool/account.go b/pool/account.go index c2f06f3..bba2fad 100644 --- a/pool/account.go +++ b/pool/account.go @@ -3,7 +3,7 @@ package pool import ( - "kiro-api-proxy/config" + "kiro-go/config" "sync" "sync/atomic" "time" diff --git a/proxy/cache_tracker.go b/proxy/cache_tracker.go new file mode 100644 index 0000000..024e4d3 --- /dev/null +++ b/proxy/cache_tracker.go @@ -0,0 +1,621 @@ +package proxy + +import ( + "bytes" + "crypto/sha256" + "encoding/json" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +const defaultPromptCacheTTL = 5 * time.Minute + +// Anthropic requires cached prefixes to reach a minimum token count before +// caching takes effect. Breakpoints below this threshold are excluded from +// matching and storage to avoid reporting unrealistic 100% cache hits on +// short requests. +const defaultMinCacheableTokens = 1024 +const opusMinCacheableTokens = 4096 + +type promptCacheUsage struct { + CacheCreationInputTokens int + CacheReadInputTokens int + CacheCreation5mInputTokens int + CacheCreation1hInputTokens int +} + +type promptCacheBreakpoint struct { + Fingerprint [32]byte + CumulativeTokens int + TTL time.Duration +} + +type promptCacheProfile struct { + Breakpoints []promptCacheBreakpoint + TotalInputTokens int + Model string +} + +func minCacheableTokensForModel(model string) int { + lower := strings.ToLower(model) + if strings.Contains(lower, "opus") { + return opusMinCacheableTokens + } + return defaultMinCacheableTokens +} + +type promptCacheEntry struct { + ExpiresAt time.Time + TTL time.Duration +} + +type promptCacheTracker struct { + mu sync.Mutex + entriesByAccount map[string]map[[32]byte]promptCacheEntry + maxSupportedTTL time.Duration +} + +func newPromptCacheTracker(maxTTL time.Duration) *promptCacheTracker { + if maxTTL <= 0 { + maxTTL = defaultPromptCacheTTL + } + return &promptCacheTracker{ + entriesByAccount: make(map[string]map[[32]byte]promptCacheEntry), + maxSupportedTTL: maxTTL, + } +} + +func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTokens int) *promptCacheProfile { + blocks := flattenClaudeCacheBlocks(req) + if len(blocks) == 0 { + return nil + } + + hasher := sha256.New() + breakpoints := make([]promptCacheBreakpoint, 0) + cumulativeTokens := 0 + var activeTTL time.Duration + + for _, block := range blocks { + canonical := canonicalizeCacheValue(block.Value) + writeHashChunk(hasher, canonical) + cumulativeTokens += block.Tokens + + // Determine whether this block acts as a cache breakpoint: + // 1) Explicit cache_control on the block itself. + // 2) Once any explicit breakpoint has been seen, every message-end + // boundary becomes an implicit breakpoint so that multi-turn + // conversations can hit earlier stored prefixes. + breakpointTTL := time.Duration(0) + if block.TTL > 0 { + breakpointTTL = block.TTL + activeTTL = block.TTL + } else if block.IsMessageEnd && activeTTL > 0 { + breakpointTTL = activeTTL + } + + if breakpointTTL <= 0 { + continue + } + + var fingerprint [32]byte + copy(fingerprint[:], hasher.Sum(nil)) + breakpoints = append(breakpoints, promptCacheBreakpoint{ + Fingerprint: fingerprint, + CumulativeTokens: cumulativeTokens, + TTL: breakpointTTL, + }) + } + + if len(breakpoints) == 0 { + return nil + } + + if totalInputTokens < cumulativeTokens { + totalInputTokens = cumulativeTokens + } + + return &promptCacheProfile{ + Breakpoints: breakpoints, + TotalInputTokens: totalInputTokens, + Model: req.Model, + } +} + +func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfile) promptCacheUsage { + if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" { + return promptCacheUsage{} + } + + minTokens := minCacheableTokensForModel(profile.Model) + last := profile.Breakpoints[len(profile.Breakpoints)-1] + lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens) + now := time.Now() + + t.mu.Lock() + defer t.mu.Unlock() + t.pruneExpiredLocked(now) + + entries := t.entriesByAccount[accountID] + if len(entries) == 0 { + // First request for this account: report creation only if above threshold. + effectiveCreation := lastTokens + if effectiveCreation < minTokens { + effectiveCreation = 0 + } + cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0) + return promptCacheUsage{ + CacheCreationInputTokens: effectiveCreation, + CacheReadInputTokens: 0, + CacheCreation5mInputTokens: cache5m, + CacheCreation1hInputTokens: cache1h, + } + } + + // Cap cacheable tokens at 85% of total input to ensure a realistic + // uncached portion. The newest content in a request is never fully + // served from cache on the current turn. + maxCacheable := int(float64(profile.TotalInputTokens) * 0.85) + if lastTokens > maxCacheable { + lastTokens = maxCacheable + } + + matchedTokens := 0 + for i := len(profile.Breakpoints) - 1; i >= 0; i-- { + breakpoint := profile.Breakpoints[i] + // Skip breakpoints below the minimum cacheable token threshold. + if breakpoint.CumulativeTokens < minTokens { + continue + } + entry, ok := entries[breakpoint.Fingerprint] + if !ok || entry.ExpiresAt.Before(now) { + continue + } + entry.ExpiresAt = now.Add(entry.TTL) + entries[breakpoint.Fingerprint] = entry + matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens) + if matchedTokens > lastTokens { + matchedTokens = lastTokens + } + break + } + + creation := maxInt(lastTokens-matchedTokens, 0) + cache5m, cache1h := computePromptCacheTTLBreakdown(profile, matchedTokens) + return promptCacheUsage{ + CacheCreationInputTokens: creation, + CacheReadInputTokens: matchedTokens, + CacheCreation5mInputTokens: cache5m, + CacheCreation1hInputTokens: cache1h, + } +} + +func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfile) { + if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" { + return + } + + minTokens := minCacheableTokensForModel(profile.Model) + now := time.Now() + t.mu.Lock() + defer t.mu.Unlock() + t.pruneExpiredLocked(now) + + entries := t.entriesByAccount[accountID] + if entries == nil { + entries = make(map[[32]byte]promptCacheEntry) + t.entriesByAccount[accountID] = entries + } + + for _, breakpoint := range profile.Breakpoints { + // Skip breakpoints below the minimum cacheable token threshold. + if breakpoint.CumulativeTokens < minTokens { + continue + } + entries[breakpoint.Fingerprint] = promptCacheEntry{ + ExpiresAt: now.Add(breakpoint.TTL), + TTL: breakpoint.TTL, + } + } +} + +func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) { + for accountID, entries := range t.entriesByAccount { + for fingerprint, entry := range entries { + if !entry.ExpiresAt.After(now) { + delete(entries, fingerprint) + } + } + if len(entries) == 0 { + delete(t.entriesByAccount, accountID) + } + } +} + +type cacheablePromptBlock struct { + Value interface{} + Tokens int + TTL time.Duration + IsMessageEnd bool +} + +func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock { + blocks := make([]cacheablePromptBlock, 0) + blocks = append(blocks, buildCachePreludeBlock(req)) + + for toolIndex, tool := range req.Tools { + toolValue := map[string]interface{}{ + "kind": "tool", + "tool_index": toolIndex, + "name": tool.Name, + "description": tool.Description, + "input_schema": tool.InputSchema, + } + fingerprintValue := stripCachePositionKeys(toolValue) + blocks = append(blocks, cacheablePromptBlock{ + Value: fingerprintValue, + Tokens: estimateApproxTokens(canonicalizeCacheValue(fingerprintValue)), + TTL: normalizePromptCacheTTL(extractPromptCacheTTL(tool)), + }) + } + + appendSystemCacheBlocks(&blocks, req.System) + + for messageIndex, msg := range req.Messages { + appendMessageCacheBlocks(&blocks, messageIndex, msg) + } + + return blocks +} + +func buildCachePreludeBlock(req *ClaudeRequest) cacheablePromptBlock { + prelude := map[string]interface{}{ + "kind": "request_prelude", + "model": req.Model, + "tool_choice": req.ToolChoice, + } + return cacheablePromptBlock{ + Value: prelude, + Tokens: estimateApproxTokens(canonicalizeCacheValue(prelude)), + } +} + +func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) { + switch v := system.(type) { + case string: + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "system", + "system_index": 0, + "block": map[string]interface{}{ + "type": "text", + "text": v, + }, + }, false) + case []interface{}: + for i, block := range v { + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "system", + "system_index": i, + "block": block, + }, false) + } + case []string: + for i, block := range v { + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "system", + "system_index": i, + "block": map[string]interface{}{ + "type": "text", + "text": block, + }, + }, false) + } + } +} + +func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, msg ClaudeMessage) { + role := msg.Role + switch content := msg.Content.(type) { + case string: + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "message", + "message_index": messageIndex, + "role": role, + "block_index": 0, + "block": map[string]interface{}{ + "type": "text", + "text": content, + }, + }, true) + case []interface{}: + lastIdx := len(content) - 1 + for blockIndex, block := range content { + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "message", + "message_index": messageIndex, + "role": role, + "block_index": blockIndex, + "block": block, + }, blockIndex == lastIdx) + } + default: + if content != nil { + appendPromptBlock(blocks, map[string]interface{}{ + "kind": "message", + "message_index": messageIndex, + "role": role, + "block_index": 0, + "block": content, + }, true) + } + } +} + +func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) { + blockValue := wrapper["block"] + ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue)) + + // Drop volatile billing metadata from the cache fingerprint. Claude Code's + // x-anthropic-billing-header can drift, appear, or disappear across + // otherwise identical requests, and it does not change model semantics. + if isAnthropicBillingHeaderBlock(blockValue) { + return + } + + fingerprintValue := stripCachePositionKeys(wrapper) + canonical := canonicalizeCacheValue(fingerprintValue) + *blocks = append(*blocks, cacheablePromptBlock{ + Value: fingerprintValue, + Tokens: estimateApproxTokens(canonical), + TTL: ttl, + IsMessageEnd: isMessageEnd, + }) +} + +func stripCachePositionKeys(value map[string]interface{}) map[string]interface{} { + cloned := make(map[string]interface{}, len(value)) + for key, item := range value { + if isCachePositionKey(key) { + continue + } + cloned[key] = item + } + return cloned +} + +func isAnthropicBillingHeaderBlock(value interface{}) bool { + blockMap, ok := value.(map[string]interface{}) + if !ok { + return false + } + + // Only normalize text blocks (or blocks without an explicit type but containing text). + if t, ok := blockMap["type"].(string); ok && t != "" && t != "text" { + return false + } + + text, ok := blockMap["text"].(string) + if !ok { + return false + } + + trimmed := strings.TrimLeft(text, " \t\r\n") + return strings.HasPrefix(strings.ToLower(trimmed), "x-anthropic-billing-header:") +} + +func extractPromptCacheTTL(value interface{}) time.Duration { + block, ok := value.(map[string]interface{}) + if !ok { + if raw, err := json.Marshal(value); err == nil { + var decoded map[string]interface{} + if json.Unmarshal(raw, &decoded) == nil { + block = decoded + ok = true + } + } + } + if !ok { + return 0 + } + + rawCache, ok := block["cache_control"] + if !ok { + return 0 + } + cacheControl, ok := rawCache.(map[string]interface{}) + if !ok { + return 0 + } + cacheType, _ := cacheControl["type"].(string) + if !strings.EqualFold(cacheType, "ephemeral") { + return 0 + } + + if ttl, ok := parsePromptCacheTTLValue(cacheControl["ttl"]); ok { + return ttl + } + return defaultPromptCacheTTL +} + +func parsePromptCacheTTLValue(value interface{}) (time.Duration, bool) { + switch v := value.(type) { + case string: + trimmed := strings.TrimSpace(strings.ToLower(v)) + if trimmed == "" { + return 0, false + } + if d, err := time.ParseDuration(trimmed); err == nil { + return d, true + } + if seconds, err := strconv.Atoi(trimmed); err == nil { + return time.Duration(seconds) * time.Second, true + } + case float64: + if v > 0 { + return time.Duration(v) * time.Second, true + } + case int: + if v > 0 { + return time.Duration(v) * time.Second, true + } + case int64: + if v > 0 { + return time.Duration(v) * time.Second, true + } + } + return 0, false +} + +func normalizePromptCacheTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return 0 + } + if ttl > time.Hour { + return time.Hour + } + if ttl > defaultPromptCacheTTL { + return time.Hour + } + return defaultPromptCacheTTL +} + +func computePromptCacheTTLBreakdown(profile *promptCacheProfile, matchedTokens int) (int, int) { + if profile == nil || len(profile.Breakpoints) == 0 { + return 0, 0 + } + + cache5m := 0 + cache1h := 0 + previous := matchedTokens + for _, breakpoint := range profile.Breakpoints { + current := minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens) + if current <= previous { + continue + } + delta := current - previous + if breakpoint.TTL >= time.Hour { + cache1h += delta + } else { + cache5m += delta + } + previous = current + } + return cache5m, cache1h +} + +func billedClaudeInputTokens(inputTokens int, usage promptCacheUsage) int { + return maxInt(inputTokens-usage.CacheCreationInputTokens-usage.CacheReadInputTokens, 0) +} + +func buildClaudeUsageMap(inputTokens, outputTokens int, usage promptCacheUsage, includeCache bool) map[string]interface{} { + result := map[string]interface{}{ + "input_tokens": billedClaudeInputTokens(inputTokens, usage), + "output_tokens": outputTokens, + } + if !includeCache { + return result + } + result["cache_creation_input_tokens"] = usage.CacheCreationInputTokens + result["cache_read_input_tokens"] = usage.CacheReadInputTokens + result["cache_creation"] = map[string]int{ + "ephemeral_5m_input_tokens": usage.CacheCreation5mInputTokens, + "ephemeral_1h_input_tokens": usage.CacheCreation1hInputTokens, + } + return result +} + +func canonicalizeCacheValue(value interface{}) string { + var buf bytes.Buffer + writeCanonicalJSON(&buf, value) + return buf.String() +} + +func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) { + switch v := value.(type) { + case nil: + buf.WriteString("null") + case string: + encoded, _ := json.Marshal(v) + buf.Write(encoded) + case bool: + if v { + buf.WriteString("true") + } else { + buf.WriteString("false") + } + case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, json.Number: + encoded, _ := json.Marshal(v) + buf.Write(encoded) + case []interface{}: + buf.WriteByte('[') + for i, item := range v { + if i > 0 { + buf.WriteByte(',') + } + writeCanonicalJSON(buf, item) + } + buf.WriteByte(']') + case map[string]interface{}: + buf.WriteByte('{') + keys := make([]string, 0, len(v)) + for key := range v { + if key == "cache_control" { + continue + } + keys = append(keys, key) + } + sort.Strings(keys) + for i, key := range keys { + if i > 0 { + buf.WriteByte(',') + } + encoded, _ := json.Marshal(key) + buf.Write(encoded) + buf.WriteByte(':') + writeCanonicalJSON(buf, v[key]) + } + buf.WriteByte('}') + default: + encoded, _ := json.Marshal(v) + buf.Write(encoded) + } +} + +func isCachePositionKey(key string) bool { + switch key { + case "tool_index", "system_index", "message_index", "block_index": + return true + default: + return false + } +} + +func writeHashChunk(hasher hashWriter, chunk string) { + length := strconv.Itoa(len(chunk)) + hasher.Write([]byte(length)) + hasher.Write([]byte{0}) + hasher.Write([]byte(chunk)) + hasher.Write([]byte{0}) +} + +type hashWriter interface { + Write([]byte) (int, error) + Sum([]byte) []byte +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/proxy/cache_tracker_test.go b/proxy/cache_tracker_test.go new file mode 100644 index 0000000..6b0262c --- /dev/null +++ b/proxy/cache_tracker_test.go @@ -0,0 +1,264 @@ +package proxy + +import ( + "strings" + "testing" + "time" +) + +func TestPromptCacheTrackerComputeAndUpdate(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + longSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + req := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": longSystem, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }, + }, + Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}}, + } + + profile := tracker.BuildClaudeProfile(req, 120) + if profile == nil { + t.Fatalf("expected cache profile to be built") + } + + first := tracker.Compute("acct-1", profile) + if first.CacheCreationInputTokens <= 0 { + t.Fatalf("expected first request to create cache tokens, got %+v", first) + } + if first.CacheReadInputTokens != 0 { + t.Fatalf("expected first request to have zero cache reads, got %+v", first) + } + + tracker.Update("acct-1", profile) + second := tracker.Compute("acct-1", profile) + if second.CacheReadInputTokens <= 0 { + t.Fatalf("expected repeated request to read cache tokens, got %+v", second) + } + if second.CacheCreationInputTokens != 0 { + t.Fatalf("expected repeated request to avoid cache creation, got %+v", second) + } +} + +func TestBuildClaudeUsageMapIncludesCacheFields(t *testing.T) { + usage := promptCacheUsage{ + CacheCreationInputTokens: 30, + CacheReadInputTokens: 20, + CacheCreation5mInputTokens: 10, + CacheCreation1hInputTokens: 20, + } + + m := buildClaudeUsageMap(100, 50, usage, true) + + if got := m["input_tokens"]; got != 50 { + t.Fatalf("expected billed input tokens 50, got %#v", got) + } + if got := m["cache_creation_input_tokens"]; got != 30 { + t.Fatalf("expected cache creation tokens 30, got %#v", got) + } + if got := m["cache_read_input_tokens"]; got != 20 { + t.Fatalf("expected cache read tokens 20, got %#v", got) + } + creation, ok := m["cache_creation"].(map[string]int) + if !ok { + t.Fatalf("expected typed cache creation map, got %#v", m["cache_creation"]) + } + if creation["ephemeral_5m_input_tokens"] != 10 || creation["ephemeral_1h_input_tokens"] != 20 { + t.Fatalf("unexpected ttl breakdown: %#v", creation) + } +} + +// TestPromptCacheStableAcrossBillingHeaderDrift verifies that Claude Code's +// per-request "x-anthropic-billing-header: cc_version=...; cch=...;" system +// block (whose content drifts on every request) does not break cache hits. +// The tracker should ignore that metadata when fingerprinting cached prefixes. +func TestPromptCacheStableAcrossBillingHeaderDrift(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + mainSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + build := func(billingHdr string) *ClaudeRequest { + return &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": billingHdr, + }, + map[string]interface{}{ + "type": "text", + "text": mainSystem, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }, + }, + Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}}, + } + } + + req1 := build("x-anthropic-billing-header: cc_version=2.1.87.1; cch=aaaa;") + profile1 := tracker.BuildClaudeProfile(req1, 2048) + if profile1 == nil { + t.Fatalf("profile1 should be built") + } + first := tracker.Compute("acct-1", profile1) + if first.CacheReadInputTokens != 0 { + t.Fatalf("expected no cache read on first request, got %+v", first) + } + tracker.Update("acct-1", profile1) + + req2 := build("x-anthropic-billing-header: cc_version=2.1.87.42; cch=bbbb; padding=xxyyzz;") + profile2 := tracker.BuildClaudeProfile(req2, 2048) + if profile2 == nil { + t.Fatalf("profile2 should be built") + } + second := tracker.Compute("acct-1", profile2) + if second.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read after billing header drift, got %+v", second) + } +} + +func TestPromptCacheStableWhenBillingHeaderAppearsOrDisappears(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + mainSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + build := func(includeBilling bool) *ClaudeRequest { + system := []interface{}{} + if includeBilling { + system = append(system, map[string]interface{}{ + "type": "text", + "text": "x-anthropic-billing-header: cc_version=2.1.87.1; cch=aaaa;", + }) + } + system = append(system, map[string]interface{}{ + "type": "text", + "text": mainSystem, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }) + return &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: system, + Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}}, + } + } + + withBilling := tracker.BuildClaudeProfile(build(true), 2048) + if withBilling == nil { + t.Fatalf("profile with billing header should be built") + } + tracker.Update("acct-1", withBilling) + + withoutBilling := tracker.BuildClaudeProfile(build(false), 2048) + if withoutBilling == nil { + t.Fatalf("profile without billing header should be built") + } + result := tracker.Compute("acct-1", withoutBilling) + if result.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read when billing header disappears, got %+v", result) + } +} + +func TestCanonicalCacheValueIgnoresPositionKeys(t *testing.T) { + first := canonicalizeCacheValue(stripCachePositionKeys(map[string]interface{}{ + "kind": "system", + "system_index": 0, + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + }, + })) + second := canonicalizeCacheValue(stripCachePositionKeys(map[string]interface{}{ + "kind": "system", + "system_index": 1, + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + }, + })) + if first != second { + t.Fatalf("expected position keys to be ignored, got %q vs %q", first, second) + } +} + +func TestCanonicalCacheValuePreservesSemanticPositionKeys(t *testing.T) { + first := canonicalizeCacheValue(map[string]interface{}{ + "kind": "system", + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + "block_index": 1, + }, + }) + second := canonicalizeCacheValue(map[string]interface{}{ + "kind": "system", + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + "block_index": 2, + }, + }) + if first == second { + t.Fatalf("expected semantic block_index fields to remain fingerprinted") + } +} + +// TestPromptCacheImplicitBreakpointAtMessageEnd verifies that once any +// explicit cache_control breakpoint has been seen, subsequent message-end +// boundaries act as implicit breakpoints. This allows multi-turn conversations +// to hit earlier stored prefix fingerprints even when the newest messages +// lack explicit cache_control. +func TestPromptCacheImplicitBreakpointAtMessageEnd(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + systemText := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + baseSystem := []interface{}{ + map[string]interface{}{ + "type": "text", + "text": systemText, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }, + } + + // Round 1: single user message. + req1 := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: baseSystem, + Messages: []ClaudeMessage{{Role: "user", Content: "question one"}}, + } + profile1 := tracker.BuildClaudeProfile(req1, 2048) + if profile1 == nil { + t.Fatalf("profile1 should be built") + } + tracker.Update("acct-1", profile1) + + // Round 2: conversation continues with new messages. The latest user + // message has no explicit cache_control; it should still hit the stored + // prefix via the implicit message-end breakpoint. + req2 := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: baseSystem, + Messages: []ClaudeMessage{ + {Role: "user", Content: "question one"}, + {Role: "assistant", Content: "answer one"}, + {Role: "user", Content: "follow-up question"}, + }, + } + profile2 := tracker.BuildClaudeProfile(req2, 4096) + if profile2 == nil { + t.Fatalf("profile2 should be built") + } + result := tracker.Compute("acct-1", profile2) + if result.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read via implicit message-end breakpoint, got %+v", result) + } +} diff --git a/proxy/handler.go b/proxy/handler.go index a610db5..a4c1098 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -4,9 +4,9 @@ import ( "encoding/json" "fmt" "io" - "kiro-api-proxy/auth" - "kiro-api-proxy/config" - "kiro-api-proxy/pool" + "kiro-go/auth" + "kiro-go/config" + "kiro-go/pool" "net/http" "strings" "sync" @@ -33,6 +33,7 @@ type Handler struct { cachedModels []ModelInfo modelsCacheMu sync.RWMutex modelsCacheTime int64 + promptCache *promptCacheTracker } type thinkingStreamSource int @@ -61,7 +62,153 @@ func allowTagSource(source *thinkingStreamSource) bool { return *source == thinkingSourceTagBlock } +func validateClaudeRequestShape(req *ClaudeRequest) string { + if len(req.Messages) == 0 { + return "messages must not be empty" + } + if msg := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); msg != "" { + return msg + } + + hasUserContext := false + lastRole := "" + for _, msg := range req.Messages { + role := strings.TrimSpace(msg.Role) + if role == "" { + continue + } + lastRole = role + if role != "user" { + continue + } + + text, images, toolResults := extractClaudeUserContent(msg.Content) + if normalizeUserContent(text, len(images) > 0) != "" || len(toolResults) > 0 { + hasUserContext = true + } + } + + if lastRole == "assistant" { + return "assistant-prefill final message is not supported; last message must be user" + } + if !hasUserContext { + return "at least one non-empty user message is required" + } + return "" +} + +func validateClaudeThinkingConfig(thinking *ClaudeThinkingConfig, maxTokens int) string { + if thinking == nil { + return "" + } + + kind := strings.ToLower(strings.TrimSpace(thinking.Type)) + switch kind { + case "enabled": + if maxTokens == 0 { + return "thinking.type enabled cannot be used with max_tokens=0" + } + if thinking.BudgetTokens <= 0 { + return "thinking.budget_tokens is required when thinking.type is enabled" + } + if thinking.BudgetTokens < 1024 { + return "thinking.budget_tokens must be at least 1024" + } + if maxTokens > 0 && thinking.BudgetTokens >= maxTokens { + return "thinking.budget_tokens must be less than max_tokens" + } + case "adaptive": + if thinking.BudgetTokens != 0 { + return "thinking.budget_tokens is not supported when thinking.type is adaptive" + } + case "disabled": + if thinking.BudgetTokens != 0 { + return "thinking.budget_tokens is not supported when thinking.type is disabled" + } + default: + return "thinking.type must be one of: enabled, adaptive, disabled" + } + + display := strings.ToLower(strings.TrimSpace(thinking.Display)) + if display != "" && display != "summarized" && display != "omitted" { + return "thinking.display must be one of: summarized, omitted" + } + if kind == "disabled" && display != "" { + return "thinking.display is not supported when thinking.type is disabled" + } + + return "" +} + +type claudeThinkingResponseOptions struct { + Format string + OmitDisplay bool +} + +func resolveClaudeThinkingResponseOptions(thinking *ClaudeThinkingConfig, defaultFormat string) claudeThinkingResponseOptions { + opts := claudeThinkingResponseOptions{Format: defaultFormat} + if opts.Format == "" { + opts.Format = "thinking" + } + if thinking == nil { + return opts + } + + display := strings.ToLower(strings.TrimSpace(thinking.Display)) + switch display { + case "summarized": + opts.Format = "thinking" + case "omitted": + opts.Format = "thinking" + opts.OmitDisplay = true + } + + return opts +} + +func validateOpenAIRequestShape(req *OpenAIRequest) string { + if len(req.Messages) == 0 { + return "messages must not be empty" + } + + hasNonSystem := false + hasUserContext := false + lastRole := "" + for _, msg := range req.Messages { + role := strings.TrimSpace(msg.Role) + if role == "" { + continue + } + if role != "system" { + hasNonSystem = true + lastRole = role + } + + if role != "user" { + continue + } + text, images := extractOpenAIUserContent(msg.Content) + if normalizeUserContent(text, len(images) > 0) != "" { + hasUserContext = true + } + } + + if !hasNonSystem { + return "at least one non-system message is required" + } + if lastRole == "assistant" { + return "assistant-prefill final message is not supported; last message must be user or tool" + } + if !hasUserContext { + return "at least one non-empty user message is required" + } + return "" +} + func NewHandler() *Handler { + // 启动时应用代理配置 + applyProxyConfig(config.GetProxyURL()) + totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats() h := &Handler{ pool: pool.GetPool(), @@ -73,6 +220,7 @@ func NewHandler() *Handler { startTime: time.Now().Unix(), stopRefresh: make(chan struct{}), stopStatsSaver: make(chan struct{}), + promptCache: newPromptCacheTracker(defaultPromptCacheTTL), } // 启动后台刷新 go h.backgroundRefresh() @@ -268,36 +416,20 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) { h.modelsCacheMu.RLock() cached := h.cachedModels h.modelsCacheMu.RUnlock() + if len(cached) == 0 { + h.refreshModelsCache() + h.modelsCacheMu.RLock() + cached = h.cachedModels + h.modelsCacheMu.RUnlock() + } thinkingSuffix := config.GetThinkingConfig().Suffix - var models []map[string]interface{} - if len(cached) > 0 { - for _, m := range cached { - supportsImage := modelSupportsImage(m.InputTypes) - models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage)) - // 自动生成 thinking 变体 - models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage)) - } - } else { - // fallback 静态列表 - models = []map[string]interface{}{ - buildModelInfo("claude-sonnet-4.6", "anthropic", true), - buildModelInfo("claude-sonnet-4.6"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-opus-4.6", "anthropic", true), - buildModelInfo("claude-opus-4.6"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-opus-4-7", "anthropic", true), - buildModelInfo("claude-opus-4-7"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-sonnet-4.5", "anthropic", true), - buildModelInfo("claude-sonnet-4.5"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-sonnet-4", "anthropic", true), - buildModelInfo("claude-sonnet-4"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-haiku-4.5", "anthropic", true), - buildModelInfo("claude-haiku-4.5"+thinkingSuffix, "anthropic", true), - buildModelInfo("claude-opus-4.5", "anthropic", true), - buildModelInfo("claude-opus-4.5"+thinkingSuffix, "anthropic", true), - } + models := buildAnthropicModelsResponse(cached, thinkingSuffix) + if len(models) == 0 { + models = fallbackAnthropicModels(thinkingSuffix) } + // 添加别名模型 models = append(models, buildModelInfo("auto", "kiro-proxy", true), @@ -310,6 +442,43 @@ func (h *Handler) handleModels(w http.ResponseWriter, r *http.Request) { "object": "list", "data": models, }) + return +} + +func buildAnthropicModelsResponse(cached []ModelInfo, thinkingSuffix string) []map[string]interface{} { + if len(cached) == 0 { + return nil + } + + models := make([]map[string]interface{}, 0, len(cached)*2) + if len(cached) > 0 { + for _, m := range cached { + supportsImage := modelSupportsImage(m.InputTypes) + models = append(models, buildModelInfo(m.ModelId, "anthropic", supportsImage)) + // 自动生成 thinking 变体 + models = append(models, buildModelInfo(m.ModelId+thinkingSuffix, "anthropic", supportsImage)) + } + } + return models +} + +func fallbackAnthropicModels(thinkingSuffix string) []map[string]interface{} { + return []map[string]interface{}{ + buildModelInfo("claude-sonnet-4.6", "anthropic", true), + buildModelInfo("claude-sonnet-4.6"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-opus-4.6", "anthropic", true), + buildModelInfo("claude-opus-4.6"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-opus-4.7", "anthropic", true), + buildModelInfo("claude-opus-4.7"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-sonnet-4.5", "anthropic", true), + buildModelInfo("claude-sonnet-4.5"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-sonnet-4", "anthropic", true), + buildModelInfo("claude-sonnet-4"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-haiku-4.5", "anthropic", true), + buildModelInfo("claude-haiku-4.5"+thinkingSuffix, "anthropic", true), + buildModelInfo("claude-opus-4.5", "anthropic", true), + buildModelInfo("claude-opus-4.5"+thinkingSuffix, "anthropic", true), + } } func modelSupportsImage(inputTypes []string) bool { @@ -357,31 +526,106 @@ func buildModelInfo(id, ownedBy string, supportsImage bool) map[string]interface // refreshModelsCache 从 Kiro API 拉取模型列表并缓存 func (h *Handler) refreshModelsCache() { - account := h.pool.GetNext() - if account == nil { + accounts := config.GetEnabledAccounts() + if len(accounts) == 0 { return } - // 确保 token 有效 - if err := h.ensureValidToken(account); err != nil { - return + aggregated := make([]ModelInfo, 0) + for i := range accounts { + account := &accounts[i] + if err := h.ensureValidToken(account); err != nil { + fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", account.Email, err) + continue + } + + models, err := ListAvailableModels(account) + if err != nil { + fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", account.Email, err) + continue + } + aggregated = mergeUniqueModels(aggregated, models) } - models, err := ListAvailableModels(account) - if err != nil { - fmt.Printf("[ModelsCache] Failed to refresh: %v\n", err) - return - } - - if len(models) > 0 { + if len(aggregated) > 0 { h.modelsCacheMu.Lock() - h.cachedModels = models + h.cachedModels = aggregated h.modelsCacheTime = time.Now().Unix() h.modelsCacheMu.Unlock() - fmt.Printf("[ModelsCache] Cached %d models\n", len(models)) + fmt.Printf("[ModelsCache] Cached %d models\n", len(aggregated)) } } +func mergeUniqueModels(existing []ModelInfo, incoming []ModelInfo) []ModelInfo { + if len(incoming) == 0 { + return existing + } + + indexByID := make(map[string]int, len(existing)) + merged := make([]ModelInfo, len(existing)) + copy(merged, existing) + for i, model := range merged { + indexByID[strings.ToLower(strings.TrimSpace(model.ModelId))] = i + } + + for _, model := range incoming { + key := strings.ToLower(strings.TrimSpace(model.ModelId)) + if key == "" { + continue + } + if idx, ok := indexByID[key]; ok { + merged[idx] = mergeModelInfo(merged[idx], model) + continue + } + indexByID[key] = len(merged) + merged = append(merged, model) + } + + return merged +} + +func mergeModelInfo(base ModelInfo, extra ModelInfo) ModelInfo { + if base.ModelName == "" { + base.ModelName = extra.ModelName + } + if base.Description == "" { + base.Description = extra.Description + } + if base.RateMultiplier == 0 { + base.RateMultiplier = extra.RateMultiplier + } + if base.TokenLimits == nil { + base.TokenLimits = extra.TokenLimits + } + base.InputTypes = mergeStringLists(base.InputTypes, extra.InputTypes) + return base +} + +func mergeStringLists(base []string, extra []string) []string { + if len(extra) == 0 { + return base + } + seen := make(map[string]bool, len(base)+len(extra)) + merged := make([]string, 0, len(base)+len(extra)) + for _, item := range base { + key := strings.ToLower(strings.TrimSpace(item)) + if key == "" || seen[key] { + continue + } + seen[key] = true + merged = append(merged, item) + } + for _, item := range extra { + key := strings.ToLower(strings.TrimSpace(item)) + if key == "" || seen[key] { + continue + } + seen[key] = true + merged = append(merged, item) + } + return merged +} + // handleCountTokens Token 计数(Claude Code 会调用) func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { @@ -400,8 +644,17 @@ func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) { h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON") return } + if msg := validateClaudeThinkingConfig(req.Thinking, req.MaxTokens); msg != "" { + h.sendClaudeError(w, 400, "invalid_request_error", msg) + return + } - estimatedTokens := estimateClaudeRequestInputTokens(&req) + thinkingCfg := config.GetThinkingConfig() + actualModel, thinking := resolveClaudeThinkingMode(req.Model, req.Thinking, thinkingCfg.Suffix) + req.Model = actualModel + effectiveReq := cloneClaudeRequestForThinking(&req, thinking) + + estimatedTokens := estimateClaudeRequestInputTokens(effectiveReq) if estimatedTokens < 1 { estimatedTokens = 1 } @@ -433,6 +686,10 @@ func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Re h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+err.Error()) return } + if msg := validateClaudeRequestShape(&req); msg != "" { + h.sendClaudeError(w, 400, "invalid_request_error", msg) + return + } // 获取账号 account := h.pool.GetNext() @@ -449,23 +706,27 @@ func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Re // 解析模型和 thinking 模式 thinkingCfg := config.GetThinkingConfig() - actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix) + actualModel, thinking := resolveClaudeThinkingMode(req.Model, req.Thinking, thinkingCfg.Suffix) req.Model = actualModel - estimatedInputTokens := estimateClaudeRequestInputTokens(&req) + effectiveReq := cloneClaudeRequestForThinking(&req, thinking) + thinkingResponseOpts := resolveClaudeThinkingResponseOptions(req.Thinking, thinkingCfg.ClaudeFormat) + estimatedInputTokens := estimateClaudeRequestInputTokens(effectiveReq) + cacheProfile := h.promptCache.BuildClaudeProfile(effectiveReq, estimatedInputTokens) + cacheUsage := h.promptCache.Compute(account.ID, cacheProfile) // 转换请求 kiroPayload := ClaudeToKiro(&req, thinking) - // 流式或非流式 + // Stream or non-stream if req.Stream { - h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens) + h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, thinkingResponseOpts, estimatedInputTokens, cacheUsage, cacheProfile) } else { - h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens) + h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, thinkingResponseOpts, estimatedInputTokens, cacheUsage, cacheProfile) } } // handleClaudeStream Claude 流式响应 -func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) { +func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) { w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -477,11 +738,12 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco } // 获取 thinking 输出格式配置 - thinkingFormat := config.GetThinkingConfig().ClaudeFormat + thinkingFormat := thinkingOpts.Format msgID := "msg_" + uuid.New().String() var inputTokens, outputTokens int var credits float64 + var realInputTokens int var toolUses []KiroToolUse var nextContentIndex int var rawContentBuilder strings.Builder @@ -593,6 +855,19 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco "delta": map[string]string{"type": "text_delta", "text": text}, }) default: + if thinkingOpts.OmitDisplay { + if thinkingState == 1 { + startContentBlock("thinking") + return + } + if thinkingState == 3 { + if activeBlockType != "thinking" { + startContentBlock("thinking") + } + closeActiveBlock() + } + return + } if thinkingState == 3 && text == "" { if activeBlockType == "thinking" { closeActiveBlock() @@ -737,10 +1012,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco "model": model, "stop_reason": nil, "stop_sequence": nil, - "usage": map[string]int{ - "input_tokens": startInputTokens, - "output_tokens": 0, - }, + "usage": buildClaudeUsageMap(startInputTokens, 0, cacheUsage, cacheProfile != nil), }, }) @@ -806,6 +1078,9 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco OnCredits: func(c float64) { credits = c }, + OnContextUsage: func(pct float64) { + realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) + }, } err := CallKiroAPI(account, payload, callback) @@ -827,7 +1102,11 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco } closeActiveBlock() - inputTokens = estimatedInputTokens + if realInputTokens > 0 { + inputTokens = realInputTokens + } else if inputTokens <= 0 { + inputTokens = estimatedInputTokens + } outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String()) thinkingOutput := rawThinkingBuilder.String() if thinking && thinkingOutput == "" && extractedReasoning != "" { @@ -841,6 +1120,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco h.recordSuccess(inputTokens, outputTokens, credits) h.pool.RecordSuccess(account.ID) h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + h.promptCache.Update(account.ID, cacheProfile) // 发送 message_delta stopReason := "end_turn" @@ -853,10 +1133,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco "delta": map[string]interface{}{ "stop_reason": stopReason, }, - "usage": map[string]int{ - "input_tokens": inputTokens, - "output_tokens": outputTokens, - }, + "usage": buildClaudeUsageMap(inputTokens, outputTokens, cacheUsage, cacheProfile != nil), }) h.sendSSE(w, flusher, "message_stop", map[string]interface{}{ @@ -925,12 +1202,13 @@ func (h *Handler) recordFailure() { } // handleClaudeNonStream Claude 非流式响应 -func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int) { +func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, thinkingOpts claudeThinkingResponseOptions, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) { var content string var thinkingContent string var toolUses []KiroToolUse var inputTokens, outputTokens int var credits float64 + var realInputTokens int callback := &KiroStreamCallback{ OnText: func(text string, isThinking bool) { @@ -953,6 +1231,9 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A OnCredits: func(c float64) { credits = c }, + OnContextUsage: func(pct float64) { + realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) + }, } err := CallKiroAPI(account, payload, callback) @@ -964,35 +1245,56 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A } // 合并 thinking 内容(如果有 reasoningContentEvent 的内容) - thinkingFormat := config.GetThinkingConfig().ClaudeFormat + thinkingFormat := thinkingOpts.Format finalContent, extractedReasoning := extractThinkingFromContent(content) - if thinking && thinkingContent == "" && extractedReasoning != "" { - thinkingContent = extractedReasoning + rawThinkingContent := thinkingContent + if thinking && rawThinkingContent == "" && extractedReasoning != "" { + rawThinkingContent = extractedReasoning } if !thinking { - thinkingContent = "" + rawThinkingContent = "" } - inputTokens = estimatedInputTokens - outputTokens = estimateClaudeOutputTokens(finalContent, thinkingContent, toolUses) + if realInputTokens > 0 { + inputTokens = realInputTokens + } else if inputTokens <= 0 { + inputTokens = estimatedInputTokens + } + outputTokens = estimateClaudeOutputTokens(finalContent, rawThinkingContent, toolUses) h.recordSuccess(inputTokens, outputTokens, credits) h.pool.RecordSuccess(account.ID) h.pool.UpdateStats(account.ID, inputTokens+outputTokens, credits) + h.promptCache.Update(account.ID, cacheProfile) - if thinking && thinkingContent != "" { + responseThinkingContent := rawThinkingContent + includeEmptyThinkingBlock := thinking && thinkingOpts.OmitDisplay && rawThinkingContent != "" + if includeEmptyThinkingBlock { + responseThinkingContent = "" + } + + if thinking && responseThinkingContent != "" { switch thinkingFormat { case "think": - finalContent = "" + thinkingContent + "" + finalContent - thinkingContent = "" + finalContent = "" + responseThinkingContent + "" + finalContent + responseThinkingContent = "" case "reasoning_content": - finalContent = thinkingContent + finalContent // Claude 格式不支持 reasoning_content,直接拼接 - thinkingContent = "" + finalContent = responseThinkingContent + finalContent // Claude 格式不支持 reasoning_content,直接拼接 + responseThinkingContent = "" default: } } - resp := KiroToClaudeResponse(finalContent, thinkingContent, toolUses, inputTokens, outputTokens, model) + resp := KiroToClaudeResponse(finalContent, responseThinkingContent, includeEmptyThinkingBlock, toolUses, inputTokens, outputTokens, model) + resp.Usage.InputTokens = billedClaudeInputTokens(inputTokens, cacheUsage) + resp.Usage.CacheCreationInputTokens = cacheUsage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens = cacheUsage.CacheReadInputTokens + if cacheProfile != nil { + resp.Usage.CacheCreation = &ClaudeCacheCreationUsage{ + Ephemeral5mInputTokens: cacheUsage.CacheCreation5mInputTokens, + Ephemeral1hInputTokens: cacheUsage.CacheCreation1hInputTokens, + } + } w.Header().Set("Content-Type", "application/json; charset=utf-8") json.NewEncoder(w).Encode(resp) } @@ -1027,6 +1329,10 @@ func (h *Handler) handleOpenAIChat(w http.ResponseWriter, r *http.Request) { h.sendOpenAIError(w, 400, "invalid_request_error", "Invalid JSON") return } + if msg := validateOpenAIRequestShape(&req); msg != "" { + h.sendOpenAIError(w, 400, "invalid_request_error", msg) + return + } account := h.pool.GetNext() if account == nil { @@ -1074,6 +1380,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco var toolCallIndex int var inputTokens, outputTokens int var credits float64 + var realInputTokens int var rawContentBuilder strings.Builder var rawReasoningBuilder strings.Builder @@ -1366,6 +1673,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco OnCredits: func(c float64) { credits = c }, + OnContextUsage: func(pct float64) { + realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) + }, } err := CallKiroAPI(account, payload, callback) @@ -1382,7 +1692,11 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco eventThinkingOpen = false } - inputTokens = estimatedInputTokens + if realInputTokens > 0 { + inputTokens = realInputTokens + } else if inputTokens <= 0 { + inputTokens = estimatedInputTokens + } outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String()) reasoningOutput := rawReasoningBuilder.String() if thinking && reasoningOutput == "" && extractedReasoning != "" { @@ -1436,6 +1750,7 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A var toolUses []KiroToolUse var inputTokens, outputTokens int var credits float64 + var realInputTokens int callback := &KiroStreamCallback{ OnText: func(text string, isThinking bool) { @@ -1449,6 +1764,9 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A OnComplete: func(inTok, outTok int) { inputTokens = inTok; outputTokens = outTok }, OnError: func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) }, OnCredits: func(c float64) { credits = c }, + OnContextUsage: func(pct float64) { + realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0) + }, } err := CallKiroAPI(account, payload, callback) @@ -1467,7 +1785,11 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A reasoningContent = "" } - inputTokens = estimatedInputTokens + if realInputTokens > 0 { + inputTokens = realInputTokens + } else if inputTokens <= 0 { + inputTokens = estimatedInputTokens + } outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses) h.recordSuccess(inputTokens, outputTokens, credits) @@ -1589,6 +1911,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { h.apiGetEndpointConfig(w, r) case path == "/endpoint" && r.Method == "POST": h.apiUpdateEndpointConfig(w, r) + case path == "/proxy" && r.Method == "GET": + h.apiGetProxy(w, r) + case path == "/proxy" && r.Method == "POST": + h.apiUpdateProxy(w, r) case path == "/version" && r.Method == "GET": h.apiGetVersion(w, r) case path == "/export" && r.Method == "POST": @@ -1822,7 +2148,7 @@ func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) { } h.pool.Reload() json.NewEncoder(w).Encode(map[string]interface{}{ - "success": true, + "success": true, "refreshed": successCount, "failed": failCount, }) @@ -2564,6 +2890,54 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request json.NewEncoder(w).Encode(map[string]bool{"success": true}) } +// applyProxyConfig 将代理配置应用到所有出站 HTTP 客户端(Kiro API + auth 模块) +func applyProxyConfig(proxyURL string) { + InitKiroHttpClient(proxyURL) + auth.InitHttpClient(proxyURL) +} + +// apiGetProxy 获取当前代理配置 +func (h *Handler) apiGetProxy(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "proxyURL": config.GetProxyURL(), + }) +} + +// apiUpdateProxy 更新代理配置并立即生效 +func (h *Handler) apiUpdateProxy(w http.ResponseWriter, r *http.Request) { + var req struct { + ProxyURL string `json:"proxyURL"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + // 验证代理 URL 格式(非空时) + if req.ProxyURL != "" { + if !strings.HasPrefix(req.ProxyURL, "http://") && + !strings.HasPrefix(req.ProxyURL, "https://") && + !strings.HasPrefix(req.ProxyURL, "socks5://") && + !strings.HasPrefix(req.ProxyURL, "socks5h://") { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "proxyURL must start with http://, https://, socks5://, or socks5h://"}) + return + } + } + + if err := config.UpdateProxySettings(req.ProxyURL); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + // 立即应用新的代理配置 + applyProxyConfig(req.ProxyURL) + + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + // apiGetVersion 获取版本信息 func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{ diff --git a/proxy/handler_test.go b/proxy/handler_test.go index e45b8dd..e905bf1 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -48,3 +48,328 @@ func TestThinkingSourceSameSourceRemainsAllowed(t *testing.T) { t.Fatalf("expected repeated reasoning source selection to stay allowed") } } + +func TestValidateOpenAIRequestShapeRejectsAssistantPrefill(t *testing.T) { + req := &OpenAIRequest{ + Messages: []OpenAIMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "prefill"}, + }, + } + + if msg := validateOpenAIRequestShape(req); msg == "" { + t.Fatalf("expected assistant-prefill final message to be rejected") + } +} + +func TestValidateOpenAIRequestShapeAllowsToolResultFinalTurn(t *testing.T) { + req := &OpenAIRequest{ + Messages: []OpenAIMessage{ + {Role: "user", Content: "find weather"}, + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Type: "function", + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{Name: "get_weather", Arguments: "{}"}, + }}, + }, + {Role: "tool", ToolCallID: "call_1", Content: "sunny"}, + }, + } + + if msg := validateOpenAIRequestShape(req); msg != "" { + t.Fatalf("expected tool-result final turn to be valid, got %q", msg) + } +} + +func TestValidateClaudeRequestShapeRejectsAssistantPrefill(t *testing.T) { + req := &ClaudeRequest{ + Messages: []ClaudeMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "prefill"}, + }, + } + + if msg := validateClaudeRequestShape(req); msg == "" { + t.Fatalf("expected assistant-prefill final message to be rejected") + } +} + +func TestResolveClaudeThinkingModeHonorsRequestThinking(t *testing.T) { + tests := []struct { + name string + model string + thinking *ClaudeThinkingConfig + wantModel string + wantThinking bool + }{ + { + name: "adaptive request enables thinking", + model: "claude-sonnet-4.6", + thinking: &ClaudeThinkingConfig{Type: "adaptive"}, + wantModel: "claude-sonnet-4.6", + wantThinking: true, + }, + { + name: "enabled request enables thinking", + model: "claude-opus-4.5", + thinking: &ClaudeThinkingConfig{Type: "enabled", BudgetTokens: 2048}, + wantModel: "claude-opus-4.5", + wantThinking: true, + }, + { + name: "disabled request keeps thinking off", + model: "claude-opus-4.7", + thinking: &ClaudeThinkingConfig{Type: "disabled"}, + wantModel: "claude-opus-4.7", + wantThinking: false, + }, + { + name: "suffix remains supported when thinking is disabled", + model: "claude-sonnet-4.5-thinking", + thinking: &ClaudeThinkingConfig{Type: "disabled"}, + wantModel: "claude-sonnet-4.5", + wantThinking: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotModel, gotThinking := resolveClaudeThinkingMode(tc.model, tc.thinking, "-thinking") + if gotModel != tc.wantModel { + t.Fatalf("expected model %q, got %q", tc.wantModel, gotModel) + } + if gotThinking != tc.wantThinking { + t.Fatalf("expected thinking=%v, got %v", tc.wantThinking, gotThinking) + } + }) + } +} + +func TestCloneClaudeRequestForThinkingInjectsPromptWithoutMutatingOriginal(t *testing.T) { + req := &ClaudeRequest{ + Model: "claude-sonnet-4.6", + System: "Follow the user instructions.", + } + + cloned := cloneClaudeRequestForThinking(req, true) + blocks, ok := cloned.System.([]interface{}) + if !ok { + t.Fatalf("expected cloned system prompt to be structured blocks, got %T", cloned.System) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 system blocks after prepend, got %d", len(blocks)) + } + gotPrompt := extractSystemPrompt(cloned.System) + expected := ThinkingModePrompt + "\n\nFollow the user instructions." + if gotPrompt != expected { + t.Fatalf("expected injected system prompt %q, got %q", expected, gotPrompt) + } + if original, ok := req.System.(string); !ok || original != "Follow the user instructions." { + t.Fatalf("expected original request system prompt to stay unchanged, got %#v", req.System) + } +} + +func TestCloneClaudeRequestForThinkingPreservesStructuredSystemBlocks(t *testing.T) { + req := &ClaudeRequest{ + Model: "claude-sonnet-4.6", + System: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "cached system", + "cache_control": map[string]interface{}{ + "type": "ephemeral", + "ttl": "5m", + }, + }, + }, + } + + cloned := cloneClaudeRequestForThinking(req, true) + blocks, ok := cloned.System.([]interface{}) + if !ok { + t.Fatalf("expected structured system blocks, got %T", cloned.System) + } + if len(blocks) != 2 { + t.Fatalf("expected 2 system blocks after prepend, got %d", len(blocks)) + } + first, ok := blocks[0].(map[string]interface{}) + if !ok || first["text"] != ThinkingModePrompt+"\n" { + t.Fatalf("expected first block to be thinking prompt, got %#v", blocks[0]) + } + second, ok := blocks[1].(map[string]interface{}) + if !ok { + t.Fatalf("expected original system block to remain a map, got %T", blocks[1]) + } + cacheControl, ok := second["cache_control"].(map[string]interface{}) + if !ok || cacheControl["type"] != "ephemeral" { + t.Fatalf("expected original cache_control to be preserved, got %#v", second["cache_control"]) + } +} + +func TestThinkingPromptAffectsClaudeTokenEstimate(t *testing.T) { + req := &ClaudeRequest{ + Model: "claude-sonnet-4.6", + Messages: []ClaudeMessage{{Role: "user", Content: "hello"}}, + } + + baseTokens := estimateClaudeRequestInputTokens(req) + thinkingTokens := estimateClaudeRequestInputTokens(cloneClaudeRequestForThinking(req, true)) + + if thinkingTokens <= baseTokens { + t.Fatalf("expected thinking tokens (%d) to exceed base tokens (%d)", thinkingTokens, baseTokens) + } +} + +func TestValidateClaudeThinkingConfig(t *testing.T) { + tests := []struct { + name string + thinking *ClaudeThinkingConfig + maxTokens int + expectError bool + }{ + { + name: "adaptive is valid", + thinking: &ClaudeThinkingConfig{Type: "adaptive"}, + maxTokens: 4096, + expectError: false, + }, + { + name: "enabled requires budget", + thinking: &ClaudeThinkingConfig{Type: "enabled"}, + maxTokens: 4096, + expectError: true, + }, + { + name: "enabled requires at least 1024 budget tokens", + thinking: &ClaudeThinkingConfig{Type: "enabled", BudgetTokens: 512}, + maxTokens: 4096, + expectError: true, + }, + { + name: "enabled rejects max tokens zero", + thinking: &ClaudeThinkingConfig{Type: "enabled", BudgetTokens: 2048}, + maxTokens: 0, + expectError: true, + }, + { + name: "enabled budget must stay below max tokens", + thinking: &ClaudeThinkingConfig{Type: "enabled", BudgetTokens: 4096}, + maxTokens: 4096, + expectError: true, + }, + { + name: "disabled rejects display", + thinking: &ClaudeThinkingConfig{Type: "disabled", Display: "summarized"}, + maxTokens: 4096, + expectError: true, + }, + { + name: "missing type is rejected", + thinking: &ClaudeThinkingConfig{}, + maxTokens: 4096, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errMsg := validateClaudeThinkingConfig(tc.thinking, tc.maxTokens) + if tc.expectError && errMsg == "" { + t.Fatalf("expected validation error") + } + if !tc.expectError && errMsg != "" { + t.Fatalf("expected thinking config to be valid, got %q", errMsg) + } + }) + } +} + +func TestResolveClaudeThinkingResponseOptions(t *testing.T) { + tests := []struct { + name string + thinking *ClaudeThinkingConfig + defaultFmt string + wantFmt string + wantOmit bool + }{ + { + name: "default config is preserved when display unset", + thinking: &ClaudeThinkingConfig{Type: "enabled", BudgetTokens: 2048}, + defaultFmt: "think", + wantFmt: "think", + wantOmit: false, + }, + { + name: "summarized forces official thinking blocks", + thinking: &ClaudeThinkingConfig{Type: "adaptive", Display: "summarized"}, + defaultFmt: "reasoning_content", + wantFmt: "thinking", + wantOmit: false, + }, + { + name: "omitted forces official thinking blocks and hides content", + thinking: &ClaudeThinkingConfig{Type: "adaptive", Display: "omitted"}, + defaultFmt: "think", + wantFmt: "thinking", + wantOmit: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + opts := resolveClaudeThinkingResponseOptions(tc.thinking, tc.defaultFmt) + if opts.Format != tc.wantFmt { + t.Fatalf("expected format %q, got %q", tc.wantFmt, opts.Format) + } + if opts.OmitDisplay != tc.wantOmit { + t.Fatalf("expected omitDisplay=%v, got %v", tc.wantOmit, opts.OmitDisplay) + } + }) + } +} + +func TestMergeUniqueModelsPreservesUnionAcrossAccounts(t *testing.T) { + base := []ModelInfo{ + {ModelId: "claude-sonnet-4.5", InputTypes: []string{"TEXT"}}, + } + incoming := []ModelInfo{ + {ModelId: "claude-sonnet-4.5", InputTypes: []string{"image"}}, + {ModelId: "claude-opus-4-7", InputTypes: []string{"text"}}, + } + + merged := mergeUniqueModels(base, incoming) + if len(merged) != 2 { + t.Fatalf("expected 2 unique models, got %d", len(merged)) + } + if !modelSupportsImage(merged[0].InputTypes) { + t.Fatalf("expected merged input types to preserve image capability, got %#v", merged[0].InputTypes) + } + if merged[1].ModelId != "claude-opus-4-7" { + t.Fatalf("expected second model to be claude-opus-4-7, got %q", merged[1].ModelId) + } +} + +func TestBuildAnthropicModelsResponseGeneratesThinkingVariants(t *testing.T) { + models := buildAnthropicModelsResponse([]ModelInfo{{ + ModelId: "claude-sonnet-4.5", + InputTypes: []string{"text", "image"}, + }}, "-thinking") + + if len(models) != 2 { + t.Fatalf("expected base model and thinking variant, got %d", len(models)) + } + if models[0]["id"] != "claude-sonnet-4.5" { + t.Fatalf("unexpected base model id: %#v", models[0]["id"]) + } + if models[1]["id"] != "claude-sonnet-4.5-thinking" { + t.Fatalf("unexpected thinking model id: %#v", models[1]["id"]) + } + if supportsImage, ok := models[0]["supports_image"].(bool); !ok || !supportsImage { + t.Fatalf("expected image capability to be preserved, got %#v", models[0]["supports_image"]) + } +} diff --git a/proxy/kiro.go b/proxy/kiro.go index a58eff8..974650a 100644 --- a/proxy/kiro.go +++ b/proxy/kiro.go @@ -7,17 +7,17 @@ import ( "encoding/json" "fmt" "io" - "kiro-api-proxy/config" + "kiro-go/config" "net/http" + "net/url" "strconv" "strings" + "sync/atomic" "time" "github.com/google/uuid" ) -const KiroVersion = "0.7.45" - // 双端点配置(429 时自动 fallback) type kiroEndpoint struct { URL string @@ -41,16 +41,48 @@ var kiroEndpoints = []kiroEndpoint{ }, } -// 全局 HTTP 客户端,复用连接池 -var kiroHttpClient = &http.Client{ - Timeout: 5 * time.Minute, - Transport: &http.Transport{ - MaxIdleConns: 100, // 最大空闲连接数 - MaxIdleConnsPerHost: 20, // 每个 Host 最大空闲连接数 - IdleConnTimeout: 90 * time.Second, // 空闲连接超时 - DisableCompression: false, // 启用压缩 - ForceAttemptHTTP2: true, // 尝试使用 HTTP/2 - }, +// 全局 HTTP 客户端,支持运行时更换(代理重配置) +var kiroHttpStore atomic.Pointer[http.Client] +var kiroRestHttpStore atomic.Pointer[http.Client] + +func init() { + InitKiroHttpClient("") +} + +// buildKiroTransport 构建带可选代理的 Transport +func buildKiroTransport(proxyURL string) *http.Transport { + t := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + if proxyURL != "" { + if u, err := url.Parse(proxyURL); err == nil { + t.Proxy = http.ProxyURL(u) + // 代理不支持 HTTP/2 协议升级 + t.ForceAttemptHTTP2 = false + } + } else { + t.Proxy = http.ProxyFromEnvironment + } + return t +} + +// InitKiroHttpClient 初始化(或重新初始化)Kiro API 的 HTTP 客户端 +func InitKiroHttpClient(proxyURL string) { + client := &http.Client{ + Timeout: 5 * time.Minute, + Transport: buildKiroTransport(proxyURL), + } + kiroHttpStore.Store(client) + + restClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: buildKiroTransport(proxyURL), + } + kiroRestHttpStore.Store(restClient) } // ==================== 请求结构 ==================== @@ -133,15 +165,16 @@ type InferenceConfig struct { TopP float64 `json:"topP,omitempty"` } -// ==================== 流式回调 ==================== +// ==================== Stream Callbacks ==================== -// KiroStreamCallback 流式响应回调 +// KiroStreamCallback stream response callbacks type KiroStreamCallback struct { - OnText func(text string, isThinking bool) - OnToolUse func(toolUse KiroToolUse) - OnComplete func(inputTokens, outputTokens int) - OnError func(err error) - OnCredits func(credits float64) + OnText func(text string, isThinking bool) + OnToolUse func(toolUse KiroToolUse) + OnComplete func(inputTokens, outputTokens int) + OnError func(err error) + OnCredits func(credits float64) + OnContextUsage func(percentage float64) } // ==================== API 调用 ==================== @@ -163,16 +196,16 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt if _, err := json.Marshal(payload); err != nil { return err } - - // User-Agent - machineId := account.MachineId - var userAgent, amzUserAgent string - if machineId != "" { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s-%s", KiroVersion, machineId) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s %s", KiroVersion, machineId) - } else { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s", KiroVersion) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s", KiroVersion) + if payload != nil && strings.TrimSpace(payload.ProfileArn) == "" { + if profileArn, err := ResolveProfileArn(account); err == nil { + payload.ProfileArn = profileArn + } else { + accountEmail := "" + if account != nil { + accountEmail = account.Email + } + fmt.Printf("[ProfileArn] Failed to resolve profile ARN for %s: %v\n", accountEmail, err) + } } // 根据配置排序端点 @@ -190,18 +223,22 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt continue } + host := "" + if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil { + host = parsedURL.Host + } + headerValues := buildStreamingHeaderValues(account, host) + req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "*/*") req.Header.Set("X-Amz-Target", ep.AmzTarget) - req.Header.Set("User-Agent", userAgent) - req.Header.Set("X-Amz-User-Agent", amzUserAgent) + applyKiroBaseHeaders(req, account, headerValues) req.Header.Set("x-amzn-kiro-agent-mode", "vibe") req.Header.Set("x-amzn-codewhisperer-optout", "true") req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - req.Header.Set("Authorization", "Bearer "+account.AccessToken) - resp, err := kiroHttpClient.Do(req) + resp, err := kiroHttpStore.Load().Do(req) if err != nil { lastErr = err fmt.Printf("[KiroAPI] Endpoint %s failed: %v\n", ep.Name, err) @@ -314,6 +351,12 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback) error { if usage, ok := event["usage"].(float64); ok { totalCredits += usage } + case "contextUsageEvent": + if pct, ok := event["contextUsagePercentage"].(float64); ok { + if callback.OnContextUsage != nil { + callback.OnContextUsage(pct) + } + } } } @@ -378,6 +421,17 @@ func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, cur return inputTokens, outputTokens } +// getContextWindowSize returns the context window size (in tokens) for a model. +func getContextWindowSize(model string) int { + m := strings.ToLower(model) + // sonnet-4.6, opus-4.6, opus-4.7 all have 1M context windows + if strings.Contains(m, "4.6") || strings.Contains(m, "4-6") || + strings.Contains(m, "4.7") || strings.Contains(m, "4-7") { + return 1_000_000 + } + return 200_000 +} + func collectUsageMaps(v interface{}, out *[]map[string]interface{}) { switch t := v.(type) { case map[string]interface{}: diff --git a/proxy/kiro_api.go b/proxy/kiro_api.go index 7252182..91c27f8 100644 --- a/proxy/kiro_api.go +++ b/proxy/kiro_api.go @@ -4,20 +4,21 @@ import ( "encoding/json" "fmt" "io" - "kiro-api-proxy/config" + "kiro-go/config" "net/http" + neturl "net/url" "strings" "time" ) const ( kiroRestAPIBase = "https://codewhisperer.us-east-1.amazonaws.com" - kiroVersion = "0.7.45" ) // GetUsageLimits 获取账户使用量和订阅信息 func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) { url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase) + url = withProfileArnQuery(url, account) req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -26,8 +27,7 @@ func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) { setKiroHeaders(req, account) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -58,8 +58,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) { setKiroHeaders(req, account) req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -80,6 +79,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) { // ListAvailableModels 获取可用模型列表 func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase) + url = withProfileArnQuery(url, account) req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -88,8 +88,7 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { setKiroHeaders(req, account) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -109,22 +108,75 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { return result.Models, nil } -func setKiroHeaders(req *http.Request, account *config.Account) { - machineId := account.MachineId - var userAgent, amzUserAgent string - if machineId != "" { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s-%s", kiroVersion, machineId) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s %s", kiroVersion, machineId) - } else { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s", kiroVersion) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s", kiroVersion) +// ResolveProfileArn returns the account profile ARN, fetching and caching it +// when it is missing. Some Kiro generation requests require this profile for +// model authorization even when model listing works without it. +func ResolveProfileArn(account *config.Account) (string, error) { + if account == nil { + return "", fmt.Errorf("account is nil") + } + if profileArn := strings.TrimSpace(account.ProfileArn); profileArn != "" { + return profileArn, nil } - req.Header.Set("Authorization", "Bearer "+account.AccessToken) + req, err := http.NewRequest("POST", fmt.Sprintf("%s/ListAvailableProfiles", kiroRestAPIBase), strings.NewReader(`{"maxResults":10}`)) + if err != nil { + return "", err + } + setKiroHeaders(req, account) + req.Header.Set("Content-Type", "application/json") + + resp, err := kiroRestHttpStore.Load().Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + for _, profile := range result.Profiles { + if profileArn := strings.TrimSpace(profile.Arn); profileArn != "" { + if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil { + fmt.Printf("[ProfileArn] Failed to cache profile ARN for %s: %v\n", account.Email, updateErr) + } + account.ProfileArn = profileArn + return profileArn, nil + } + } + return "", fmt.Errorf("no available Kiro profile") +} + +func withProfileArnQuery(rawURL string, account *config.Account) string { + if account == nil { + return rawURL + } + profileArn := strings.TrimSpace(account.ProfileArn) + if profileArn == "" { + return rawURL + } + return rawURL + "&profileArn=" + neturl.QueryEscape(profileArn) +} + +func setKiroHeaders(req *http.Request, account *config.Account) { + host := "" + if req.URL != nil { + host = req.URL.Host + } + headerValues := buildRuntimeHeaderValues(account, host) + req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", userAgent) - req.Header.Set("x-amz-user-agent", amzUserAgent) - req.Header.Set("x-amzn-codewhisperer-optout", "true") + applyKiroBaseHeaders(req, account, headerValues) } // RefreshAccountInfo 刷新账户信息(使用量、订阅等) @@ -156,7 +208,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { return nil, fmt.Errorf("Account suspended: %w", err) } else if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || - strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") { + strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") { // Token 相关错误,可能需要重新认证 fmt.Printf("[RefreshAccountInfo] Authentication error for %s: %v\n", account.Email, err) @@ -286,14 +338,14 @@ type UsageLimitsResponse struct { } type UsageBreakdown struct { - ResourceType string `json:"resourceType"` - CurrentUsage float64 `json:"currentUsage"` - UsageLimit float64 `json:"usageLimit"` - Currency string `json:"currency"` - Unit string `json:"unit"` - OverageRate float64 `json:"overageRate"` - FreeTrialInfo *FreeTrialInfo `json:"freeTrialInfo"` - Bonuses []BonusInfo `json:"bonuses"` + ResourceType string `json:"resourceType"` + CurrentUsage float64 `json:"currentUsage"` + UsageLimit float64 `json:"usageLimit"` + Currency string `json:"currency"` + Unit string `json:"unit"` + OverageRate float64 `json:"overageRate"` + FreeTrialInfo *FreeTrialInfo `json:"freeTrialInfo"` + Bonuses []BonusInfo `json:"bonuses"` } type FreeTrialInfo struct { diff --git a/proxy/kiro_api_test.go b/proxy/kiro_api_test.go new file mode 100644 index 0000000..4fce7cd --- /dev/null +++ b/proxy/kiro_api_test.go @@ -0,0 +1,96 @@ +package proxy + +import ( + "io" + "kiro-go/config" + "net/http" + "path/filepath" + "strings" + "testing" +) + +func TestResolveProfileArnReturnsCachedValueWithoutRequest(t *testing.T) { + kiroRestHttpStore.Store(&http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("unexpected HTTP request for cached profile ARN") + return nil, nil + }), + }) + t.Cleanup(func() { InitKiroHttpClient("") }) + + account := &config.Account{ProfileArn: " arn:aws:codewhisperer:profile/test "} + got, err := ResolveProfileArn(account) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "arn:aws:codewhisperer:profile/test" { + t.Fatalf("expected trimmed cached ARN, got %q", got) + } +} + +func TestResolveProfileArnFetchesAndCachesProfile(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + if err := config.Init(configPath); err != nil { + t.Fatalf("init config: %v", err) + } + account := config.Account{ + ID: "acct-1", + Email: "user@example.com", + AccessToken: "access-token", + Region: "us-east-1", + UsageCurrent: 7, + } + if err := config.AddAccount(account); err != nil { + t.Fatalf("add account: %v", err) + } + + kiroRestHttpStore.Store(&http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", req.Method) + } + if req.URL.Path != "/ListAvailableProfiles" { + t.Fatalf("expected ListAvailableProfiles path, got %s", req.URL.Path) + } + if got := req.Header.Get("Content-Type"); got != "application/json" { + t.Fatalf("expected JSON content type, got %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"profiles":[{"arn":" arn:aws:codewhisperer:profile/fetched "}]} `)), + Header: make(http.Header), + }, nil + }), + }) + t.Cleanup(func() { InitKiroHttpClient("") }) + + requestAccount := account + requestAccount.UsageCurrent = 0 + got, err := ResolveProfileArn(&requestAccount) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "arn:aws:codewhisperer:profile/fetched" { + t.Fatalf("expected fetched ARN, got %q", got) + } + if requestAccount.ProfileArn != got { + t.Fatalf("expected account to be updated with fetched ARN, got %q", requestAccount.ProfileArn) + } + + accounts := config.GetAccounts() + if len(accounts) != 1 { + t.Fatalf("expected one persisted account, got %d", len(accounts)) + } + if accounts[0].ProfileArn != got { + t.Fatalf("expected persisted account profile ARN %q, got %q", got, accounts[0].ProfileArn) + } + if accounts[0].UsageCurrent != 7 { + t.Fatalf("expected profile cache update to preserve usage fields, got usageCurrent=%v", accounts[0].UsageCurrent) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/proxy/kiro_headers.go b/proxy/kiro_headers.go new file mode 100644 index 0000000..baf3fc6 --- /dev/null +++ b/proxy/kiro_headers.go @@ -0,0 +1,68 @@ +package proxy + +import ( + "fmt" + "kiro-go/config" + "net/http" +) + +const ( + kiroStreamingSDKVersion = "1.0.34" + kiroRuntimeSDKVersion = "1.0.0" +) + +type kiroHeaderValues struct { + UserAgent string + AmzUserAgent string + Host string +} + +func buildStreamingHeaderValues(account *config.Account, host string) kiroHeaderValues { + return buildKiroHeaderValues(account, host, "codewhispererstreaming", kiroStreamingSDKVersion, "m/E") +} + +func buildRuntimeHeaderValues(account *config.Account, host string) kiroHeaderValues { + return buildKiroHeaderValues(account, host, "codewhispererruntime", kiroRuntimeSDKVersion, "m/N,E") +} + +func buildKiroHeaderValues(account *config.Account, host, apiName, sdkVersion, mode string) kiroHeaderValues { + clientCfg := config.GetKiroClientConfig() + machineID := "" + if account != nil { + machineID = account.MachineId + } + + userAgent := fmt.Sprintf( + "aws-sdk-js/%s ua/2.1 os/%s lang/js md/nodejs#%s api/%s#%s %s KiroIDE-%s", + sdkVersion, + clientCfg.SystemVersion, + clientCfg.NodeVersion, + apiName, + sdkVersion, + mode, + clientCfg.KiroVersion, + ) + amzUserAgent := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s", sdkVersion, clientCfg.KiroVersion) + if machineID != "" { + userAgent += "-" + machineID + amzUserAgent += "-" + machineID + } + + return kiroHeaderValues{ + UserAgent: userAgent, + AmzUserAgent: amzUserAgent, + Host: host, + } +} + +func applyKiroBaseHeaders(req *http.Request, account *config.Account, values kiroHeaderValues) { + if account != nil && account.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+account.AccessToken) + } + req.Header.Set("User-Agent", values.UserAgent) + req.Header.Set("x-amz-user-agent", values.AmzUserAgent) + req.Header.Set("x-amzn-codewhisperer-optout", "true") + if values.Host != "" { + req.Host = values.Host + } +} diff --git a/proxy/kiro_headers_test.go b/proxy/kiro_headers_test.go new file mode 100644 index 0000000..a4b0805 --- /dev/null +++ b/proxy/kiro_headers_test.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "kiro-go/config" + "strings" + "testing" +) + +func TestBuildStreamingHeaderValuesAlignsWithKiroIDEFormat(t *testing.T) { + account := &config.Account{MachineId: "machine-123"} + values := buildStreamingHeaderValues(account, "q.us-east-1.amazonaws.com") + + if values.Host != "q.us-east-1.amazonaws.com" { + t.Fatalf("expected host to be preserved, got %q", values.Host) + } + if !strings.Contains(values.UserAgent, "aws-sdk-js/1.0.34") { + t.Fatalf("expected streaming sdk version in user agent, got %q", values.UserAgent) + } + if !strings.Contains(values.UserAgent, "api/codewhispererstreaming#1.0.34") { + t.Fatalf("expected streaming API marker in user agent, got %q", values.UserAgent) + } + if !strings.Contains(values.UserAgent, "KiroIDE-0.11.107-machine-123") { + t.Fatalf("expected kiro version and machine id in user agent, got %q", values.UserAgent) + } + if !strings.Contains(values.AmzUserAgent, "aws-sdk-js/1.0.34 KiroIDE-0.11.107-machine-123") { + t.Fatalf("expected x-amz-user-agent to include version and machine id, got %q", values.AmzUserAgent) + } +} + +func TestBuildRuntimeHeaderValuesUsesRuntimeAPIFormat(t *testing.T) { + account := &config.Account{MachineId: "machine-456"} + values := buildRuntimeHeaderValues(account, "codewhisperer.us-east-1.amazonaws.com") + + if !strings.Contains(values.UserAgent, "aws-sdk-js/1.0.0") { + t.Fatalf("expected runtime sdk version in user agent, got %q", values.UserAgent) + } + if !strings.Contains(values.UserAgent, "api/codewhispererruntime#1.0.0") { + t.Fatalf("expected runtime API marker in user agent, got %q", values.UserAgent) + } + if !strings.Contains(values.UserAgent, "m/N,E") { + t.Fatalf("expected runtime mode marker in user agent, got %q", values.UserAgent) + } +} diff --git a/proxy/kiro_test.go b/proxy/kiro_test.go index f32190b..003e544 100644 --- a/proxy/kiro_test.go +++ b/proxy/kiro_test.go @@ -1,6 +1,11 @@ package proxy -import "testing" +import ( + "net/http" + "net/url" + "testing" + "time" +) func TestNormalizeChunkBasicProgression(t *testing.T) { prev := "" @@ -35,3 +40,63 @@ func TestNormalizeChunkOverlapDelta(t *testing.T) { t.Fatalf("expected overlap suffix delta, got %q", got) } } + +func TestBuildKiroTransportUsesExplicitProxyURL(t *testing.T) { + transport := buildKiroTransport("http://proxy.local:8080") + req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://proxy.local:8080") +} + +func TestBuildKiroTransportFallsBackToEnvironmentProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + transport := buildKiroTransport("") + req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://env-proxy.local:2323") +} + +func TestInitKiroHttpClientKeepsShortRestTimeout(t *testing.T) { + InitKiroHttpClient("") + t.Cleanup(func() { InitKiroHttpClient("") }) + + streamClient := kiroHttpStore.Load() + restClient := kiroRestHttpStore.Load() + + if streamClient.Timeout != 5*time.Minute { + t.Fatalf("expected streaming timeout to be 5m, got %s", streamClient.Timeout) + } + if restClient.Timeout != 30*time.Second { + t.Fatalf("expected REST timeout to stay 30s, got %s", restClient.Timeout) + } +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + if err != nil { + t.Fatalf("invalid test URL: %v", err) + } + return parsed +} + +func assertProxyURL(t *testing.T, got *url.URL, want string) { + t.Helper() + if got == nil { + t.Fatalf("expected proxy URL %q, got nil", want) + } + if got.String() != want { + t.Fatalf("expected proxy URL %q, got %q", want, got.String()) + } +} diff --git a/proxy/translator.go b/proxy/translator.go index 64c4128..38b562e 100644 --- a/proxy/translator.go +++ b/proxy/translator.go @@ -22,8 +22,8 @@ var modelMapOrdered = []modelMapping{ {"claude-sonnet-4.5", "claude-sonnet-4.5"}, {"claude-sonnet-4-6", "claude-sonnet-4.6"}, {"claude-sonnet-4.6", "claude-sonnet-4.6"}, - {"claude-opus-4-7", "claude-opus-4-7"}, - {"claude-opus-4.7", "claude-opus-4-7"}, + {"claude-opus-4-7", "claude-opus-4.7"}, + {"claude-opus-4.7", "claude-opus-4.7"}, {"claude-haiku-4-5", "claude-haiku-4.5"}, {"claude-haiku-4.5", "claude-haiku-4.5"}, {"claude-opus-4-5", "claude-opus-4.5"}, @@ -46,6 +46,7 @@ const ThinkingModePrompt = `enabled 200000` const minimalFallbackUserContent = "." +const toolResultsContinuationPrefix = "Tool results:" // ParseModelAndThinking 解析模型名称,返回实际模型和是否启用 thinking func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) { @@ -72,7 +73,20 @@ func ParseModelAndThinking(model string, thinkingSuffix string) (string, bool) { return model, thinking } - return "claude-sonnet-4.5", thinking + return model, thinking +} + +func resolveClaudeThinkingMode(model string, thinkingCfg *ClaudeThinkingConfig, thinkingSuffix string) (string, bool) { + actualModel, suffixThinking := ParseModelAndThinking(model, thinkingSuffix) + return actualModel, suffixThinking || isClaudeThinkingRequested(thinkingCfg) +} + +func isClaudeThinkingRequested(thinkingCfg *ClaudeThinkingConfig) bool { + if thinkingCfg == nil { + return false + } + kind := strings.ToLower(strings.TrimSpace(thinkingCfg.Type)) + return kind == "enabled" || kind == "adaptive" } func MapModel(model string) string { @@ -83,15 +97,22 @@ func MapModel(model string) string { // ==================== Claude API 类型 ==================== type ClaudeRequest struct { - Model string `json:"model"` - Messages []ClaudeMessage `json:"messages"` - MaxTokens int `json:"max_tokens"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Stream bool `json:"stream,omitempty"` - System interface{} `json:"system,omitempty"` // string or []SystemBlock - Tools []ClaudeTool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` + Model string `json:"model"` + Messages []ClaudeMessage `json:"messages"` + MaxTokens int `json:"max_tokens"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + System interface{} `json:"system,omitempty"` // string or []SystemBlock + Thinking *ClaudeThinkingConfig `json:"thinking,omitempty"` + Tools []ClaudeTool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` +} + +type ClaudeThinkingConfig struct { + Type string `json:"type,omitempty"` + BudgetTokens int `json:"budget_tokens,omitempty"` + Display string `json:"display,omitempty"` } type ClaudeMessage struct { @@ -103,6 +124,7 @@ type ClaudeContentBlock struct { Type string `json:"type"` Text string `json:"text,omitempty"` Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` Input interface{} `json:"input,omitempty"` @@ -134,9 +156,17 @@ type ClaudeResponse struct { Usage ClaudeUsage `json:"usage"` } +type ClaudeCacheCreationUsage struct { + Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens,omitempty"` + Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens,omitempty"` +} + type ClaudeUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` + CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + CacheCreation *ClaudeCacheCreationUsage `json:"cache_creation,omitempty"` } // ==================== Claude -> Kiro 转换 ==================== @@ -148,12 +178,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { origin := "AI_EDITOR" // 提取系统提示 - systemPrompt := extractSystemPrompt(req.System) - - // 如果启用 thinking 模式,注入 thinking 提示 - if thinking { - systemPrompt = ThinkingModePrompt + "\n\n" + systemPrompt - } + systemPrompt := buildClaudeSystemPrompt(req.System, thinking) // 构建历史消息 history := make([]KiroHistoryMessage, 0) @@ -175,7 +200,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { } else { userMsg := KiroUserInputMessage{ Content: content, - // ModelID: modelID, + ModelID: modelID, Origin: origin, } if len(images) > 0 { @@ -201,16 +226,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { } } - // 确保 history 以 user 开始 - if len(history) > 0 && history[0].AssistantResponseMessage != nil { - history = append([]KiroHistoryMessage{{ - UserInputMessage: &KiroUserInputMessage{ - Content: "Begin conversation", - // ModelID: modelID, - Origin: origin, - }, - }}, history...) - } + history = trimLeadingAssistantHistory(history) // 构建最终内容 finalContent := "" @@ -236,7 +252,7 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstClaudeConversationAnchor(req.Messages)) payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ Content: finalContent, - // ModelID: modelID, + ModelID: modelID, Origin: origin, Images: currentImages, } @@ -263,6 +279,88 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { return payload } +func buildClaudeSystemPrompt(system interface{}, thinking bool) string { + systemPrompt := extractSystemPrompt(system) + if !thinking { + return systemPrompt + } + if systemPrompt == "" { + return ThinkingModePrompt + } + return ThinkingModePrompt + "\n\n" + systemPrompt +} + +func cloneClaudeRequestForThinking(req *ClaudeRequest, thinking bool) *ClaudeRequest { + if req == nil { + return nil + } + + cloned := *req + if thinking { + cloned.System = prependThinkingSystem(req.System) + } + return &cloned +} + +func prependThinkingSystem(system interface{}) interface{} { + thinkingText := ThinkingModePrompt + if hasClaudeSystemContent(system) { + thinkingText += "\n" + } + thinkingBlock := map[string]interface{}{ + "type": "text", + "text": thinkingText, + } + + switch v := system.(type) { + case nil: + return []interface{}{thinkingBlock} + case string: + if v == "" { + return []interface{}{thinkingBlock} + } + return []interface{}{ + thinkingBlock, + map[string]interface{}{ + "type": "text", + "text": v, + }, + } + case []interface{}: + blocks := make([]interface{}, 0, len(v)+1) + blocks = append(blocks, thinkingBlock) + blocks = append(blocks, v...) + return blocks + case []string: + blocks := make([]interface{}, 0, len(v)+1) + blocks = append(blocks, thinkingBlock) + for _, block := range v { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": block, + }) + } + return blocks + default: + return []interface{}{thinkingBlock} + } +} + +func hasClaudeSystemContent(system interface{}) bool { + switch v := system.(type) { + case nil: + return false + case string: + return v != "" + case []interface{}: + return len(v) > 0 + case []string: + return len(v) > 0 + default: + return true + } +} + func extractSystemPrompt(system interface{}) string { if system == nil { return "" @@ -459,10 +557,10 @@ func shortenToolName(name string) string { // ==================== Kiro -> Claude 转换 ==================== -func KiroToClaudeResponse(content, thinkingContent string, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse { +func KiroToClaudeResponse(content, thinkingContent string, includeEmptyThinkingBlock bool, toolUses []KiroToolUse, inputTokens, outputTokens int, model string) *ClaudeResponse { blocks := make([]ClaudeContentBlock, 0) - if thinkingContent != "" { + if thinkingContent != "" || includeEmptyThinkingBlock { blocks = append(blocks, ClaudeContentBlock{ Type: "thinking", Thinking: thinkingContent, @@ -615,7 +713,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload { history = append(history, KiroHistoryMessage{ UserInputMessage: &KiroUserInputMessage{ Content: content, - // ModelID: modelID, + ModelID: modelID, Origin: origin, Images: images, }, @@ -661,7 +759,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload { history = append(history, KiroHistoryMessage{ UserInputMessage: &KiroUserInputMessage{ Content: buildToolResultsContinuation(currentToolResults), - // ModelID: modelID, + ModelID: modelID, Origin: origin, UserInputMessageContext: &UserInputMessageContext{ ToolResults: currentToolResults, @@ -698,7 +796,7 @@ func OpenAIToKiro(req *OpenAIRequest, thinking bool) *KiroPayload { payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstOpenAIConversationAnchor(nonSystemMessages)) payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ Content: finalContent, - // ModelID: modelID, + ModelID: modelID, Origin: origin, Images: currentImages, } @@ -832,13 +930,27 @@ func buildToolResultsContinuation(toolResults []KiroToolResult) string { return minimalFallbackUserContent } - joined := strings.Join(parts, "\n\n") + joined := toolResultsContinuationPrefix + "\n\n" + strings.Join(parts, "\n\n") if len(joined) > 4000 { return joined[:4000] } return joined } +func trimLeadingAssistantHistory(history []KiroHistoryMessage) []KiroHistoryMessage { + idx := 0 + for idx < len(history) && history[idx].AssistantResponseMessage != nil { + idx++ + } + if idx == 0 { + return history + } + if idx >= len(history) { + return nil + } + return history[idx:] +} + func firstClaudeConversationAnchor(messages []ClaudeMessage) string { for _, msg := range messages { if msg.Role != "user" { @@ -849,15 +961,7 @@ func firstClaudeConversationAnchor(messages []ClaudeMessage) string { return strings.TrimSpace(text) } if len(toolResults) > 0 { - return buildToolResultsContinuation(toolResults) - } - } - - for _, msg := range messages { - if strings.TrimSpace(msg.Role) != "" { - if text := extractOpenAIMessageText(msg.Content); strings.TrimSpace(text) != "" { - return strings.TrimSpace(text) - } + continue } } @@ -875,25 +979,32 @@ func firstOpenAIConversationAnchor(messages []OpenAIMessage) string { } } - for _, msg := range messages { - text := extractOpenAIMessageText(msg.Content) - if strings.TrimSpace(text) != "" { - return strings.TrimSpace(text) - } - } - return "" } func buildConversationID(modelID, systemPrompt, anchor string) string { anchor = strings.TrimSpace(anchor) - if anchor == "" { + if isSyntheticConversationAnchor(anchor) { return uuid.New().String() } seed := strings.Join([]string{modelID, strings.TrimSpace(systemPrompt), anchor}, "\n") return uuid.NewSHA1(uuid.NameSpaceURL, []byte(seed)).String() } +func isSyntheticConversationAnchor(anchor string) bool { + if strings.TrimSpace(anchor) == "" { + return true + } + + normalized := strings.ToLower(strings.Join(strings.Fields(anchor), " ")) + switch normalized { + case ".", "begin conversation", "please analyze the attached image.", strings.ToLower(minimalFallbackUserContent): + return true + default: + return false + } +} + func extractOpenAITextPart(part map[string]interface{}) (string, bool) { partType, _ := part["type"].(string) switch partType { diff --git a/proxy/translator_test.go b/proxy/translator_test.go index c650081..e0f276f 100644 --- a/proxy/translator_test.go +++ b/proxy/translator_test.go @@ -76,7 +76,7 @@ func TestOpenAIToKiroPreservesStructuredAssistantAndToolContent(t *testing.T) { } cur := payload.ConversationState.CurrentMessage.UserInputMessage - if cur.Content != "tool-result-structured" { + if !strings.Contains(cur.Content, "tool-result-structured") { t.Fatalf("expected tool-result continuation content, got %q", cur.Content) } if cur.UserInputMessageContext == nil || len(cur.UserInputMessageContext.ToolResults) != 1 { @@ -196,3 +196,84 @@ func TestClaudeConversationIDStableFromAnchor(t *testing.T) { t.Fatalf("expected stable conversation ID across turns, got %q vs %q", payloadA.ConversationState.ConversationID, payloadB.ConversationState.ConversationID) } } + +func TestOpenAIConversationIDRandomForSyntheticAnchor(t *testing.T) { + req := &OpenAIRequest{ + Model: "claude-sonnet-4.5", + Messages: []OpenAIMessage{ + {Role: "assistant", Content: "prefill"}, + }, + } + + payloadA := OpenAIToKiro(req, false) + payloadB := OpenAIToKiro(req, false) + + if payloadA.ConversationState.ConversationID == payloadB.ConversationState.ConversationID { + t.Fatalf("expected synthetic anchor to generate non-deterministic conversation IDs") + } +} + +func TestClaudeToKiroDropsLeadingAssistantHistory(t *testing.T) { + req := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + Messages: []ClaudeMessage{ + {Role: "assistant", Content: "prefill"}, + {Role: "user", Content: "real user message"}, + }, + } + + payload := ClaudeToKiro(req, false) + + if len(payload.ConversationState.History) != 0 { + t.Fatalf("expected leading assistant-only history to be dropped, got %d entries", len(payload.ConversationState.History)) + } + + if strings.Contains(payload.ConversationState.CurrentMessage.UserInputMessage.Content, "Begin conversation") { + t.Fatalf("unexpected synthetic Begin conversation injection in current content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content) + } +} + +func TestKiroToClaudeResponseCanEmitEmptyThinkingBlock(t *testing.T) { + resp := KiroToClaudeResponse("final answer", "", true, nil, 10, 20, "claude-sonnet-4.6") + + if len(resp.Content) != 2 { + t.Fatalf("expected empty thinking block plus text block, got %d blocks", len(resp.Content)) + } + if resp.Content[0].Type != "thinking" { + t.Fatalf("expected first block to be thinking, got %#v", resp.Content[0]) + } + if resp.Content[0].Thinking != "" { + t.Fatalf("expected omitted thinking block to have empty content, got %#v", resp.Content[0].Thinking) + } + if resp.Content[1].Type != "text" || resp.Content[1].Text != "final answer" { + t.Fatalf("expected text block to be preserved, got %#v", resp.Content[1]) + } +} + +func TestToolResultsContinuationIncludesInstructionPrefix(t *testing.T) { + req := &OpenAIRequest{ + Model: "claude-sonnet-4.5", + Messages: []OpenAIMessage{ + {Role: "user", Content: "find data"}, + {Role: "assistant", ToolCalls: []ToolCall{{ + ID: "call_1", + Type: "function", + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{Name: "fetch", Arguments: "{}"}, + }}}, + {Role: "tool", ToolCallID: "call_1", Content: "result-1"}, + }, + } + + payload := OpenAIToKiro(req, false) + content := payload.ConversationState.CurrentMessage.UserInputMessage.Content + + if !strings.Contains(content, toolResultsContinuationPrefix) { + t.Fatalf("expected tool continuation prefix, got %q", content) + } + if !strings.Contains(content, "result-1") { + t.Fatalf("expected tool result text in continuation content, got %q", content) + } +} diff --git a/version.json b/version.json index e206569..14e14d7 100644 --- a/version.json +++ b/version.json @@ -1,5 +1,5 @@ { - "version": "1.0.3", - "changelog": "✅ 新增 clientID/clientSecret 校验\n⚖️ 新增账号权重字段,支持加权轮询策略\n🔄 批量账号管理(启用/禁用/刷新/详情)\n🚫 自动跳过用量耗尽的账号\n🔧 重构模型映射为有序列表,避免误匹配", + "version": "1.0.6", + "changelog": "✨ Added and fixed several improvements across the project.\n✨ 新增并修复了一些内容,包含若干功能改进与问题修复。", "download": "https://github.com/Quorinex/Kiro-Go" } diff --git a/web/index.html b/web/index.html index a417281..1acdc92 100644 --- a/web/index.html +++ b/web/index.html @@ -1017,6 +1017,34 @@ id="newPassword" data-i18n-placeholder="settings.newPasswordPlaceholder"> +
+
+
+ + +
+ + +
@@ -1147,6 +1175,16 @@ 'settings.statistics': '统计', 'settings.resetStats': '重置统计', 'settings.confirmReset': '确定重置统计?', + 'settings.proxySettings': '出站代理设置', + 'settings.proxyType': '代理类型', + 'settings.proxyNone': '直连(不使用代理)', + 'settings.proxyHost': '地址 / 端口', + 'settings.proxyAuth': '认证(可选)', + 'settings.proxyUsername': '用户名', + 'settings.proxyPassword': '密码', + 'settings.proxyHostRequired': '请填写代理地址和端口', + 'settings.saveProxy': '保存代理设置', + 'settings.proxySaved': '代理设置已保存', 'api.endpoints': 'API 端点', 'api.modelList': '模型列表', 'api.stats': '统计数据', @@ -1357,6 +1395,16 @@ 'settings.statistics': 'Statistics', 'settings.resetStats': 'Reset Statistics', 'settings.confirmReset': 'Confirm reset statistics?', + 'settings.proxySettings': 'Outbound Proxy Settings', + 'settings.proxyType': 'Proxy Type', + 'settings.proxyNone': 'Direct (no proxy)', + 'settings.proxyHost': 'Host / Port', + 'settings.proxyAuth': 'Authentication (optional)', + 'settings.proxyUsername': 'Username', + 'settings.proxyPassword': 'Password', + 'settings.proxyHostRequired': 'Please enter proxy host and port', + 'settings.saveProxy': 'Save Proxy Settings', + 'settings.proxySaved': 'Proxy settings saved', 'api.endpoints': 'API Endpoints', 'api.modelList': 'Model List', 'api.stats': 'Statistics', @@ -2021,6 +2069,7 @@ document.getElementById('apiKeyInput').value = d.apiKey || ''; loadThinkingConfig(); loadEndpointConfig(); + loadProxyConfig(); } async function loadThinkingConfig() { const res = await fetch('/admin/api/thinking', { headers: { 'X-Admin-Password': password } }); @@ -2050,6 +2099,52 @@ const d = await res.json(); if (d.success) { alert(t('settings.endpointSaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); } } + async function loadProxyConfig() { + const res = await fetch('/admin/api/proxy', { headers: { 'X-Admin-Password': password } }); + const d = await res.json(); + const proxyURL = d.proxyURL || ''; + if (!proxyURL) { + document.getElementById('proxyType').value = 'none'; + document.getElementById('proxyFields').style.display = 'none'; + return; + } + try { + const u = new URL(proxyURL); + const scheme = u.protocol.replace(':', ''); + document.getElementById('proxyType').value = scheme.startsWith('socks5') ? 'socks5' : 'http'; + document.getElementById('proxyHost').value = u.hostname; + document.getElementById('proxyPort').value = u.port; + document.getElementById('proxyUsername').value = decodeURIComponent(u.username); + document.getElementById('proxyPassword').value = decodeURIComponent(u.password); + document.getElementById('proxyFields').style.display = ''; + } catch(e) { + document.getElementById('proxyType').value = 'none'; + document.getElementById('proxyFields').style.display = 'none'; + } + } + function onProxyTypeChange() { + const type = document.getElementById('proxyType').value; + document.getElementById('proxyFields').style.display = type === 'none' ? 'none' : ''; + } + async function saveProxyConfig() { + const type = document.getElementById('proxyType').value; + let proxyURL = ''; + if (type !== 'none') { + const host = document.getElementById('proxyHost').value.trim(); + const port = document.getElementById('proxyPort').value.trim(); + if (!host || !port) { alert(t('settings.proxyHostRequired')); return; } + const user = document.getElementById('proxyUsername').value.trim(); + const pass = document.getElementById('proxyPassword').value.trim(); + const auth = user ? (pass ? `${encodeURIComponent(user)}:${encodeURIComponent(pass)}@` : `${encodeURIComponent(user)}@`) : ''; + proxyURL = `${type}://${auth}${host}:${port}`; + } + const res = await fetch('/admin/api/proxy', { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify({ proxyURL }) + }); + const d = await res.json(); + if (d.success) { alert(t('settings.proxySaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); } + } function generateApiKey() { const chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'; let key = 'sk-';