diff --git a/.dockerignore b/.dockerignore index ab803d44..770145cb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -61,6 +61,9 @@ temp/ deploy/install.sh deploy/sub2api.service deploy/sub2api-sudoers +deploy/data/ +deploy/postgres_data/ +deploy/redis_data/ # GoReleaser .goreleaser.yaml diff --git a/.gitattributes b/.gitattributes index 3db3b83d..37e3bee2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf # Go 源代码文件 *.go text eol=lf +# 前端 源代码文件 +*.ts text eol=lf +*.tsx text eol=lf +*.js text eol=lf +*.jsx text eol=lf +*.vue text eol=lf + # Shell 脚本 *.sh text eol=lf diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5c0524c8..c51b3c07 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -271,3 +271,36 @@ jobs: parse_mode: "Markdown", disable_web_page_preview: true }')" + + sync-version-file: + needs: [release] + if: ${{ needs.release.result == 'success' }} + runs-on: ubuntu-latest + steps: + - name: Checkout default branch + uses: actions/checkout@v6 + with: + ref: ${{ github.event.repository.default_branch }} + + - name: Sync VERSION file to released tag + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + VERSION=${{ github.event.inputs.tag }} + VERSION=${VERSION#v} + else + VERSION=${GITHUB_REF#refs/tags/v} + fi + + CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true) + if [ "$CURRENT_VERSION" = "$VERSION" ]; then + echo "VERSION file already matches $VERSION" + exit 0 + fi + + echo "$VERSION" > backend/cmd/server/VERSION + + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add backend/cmd/server/VERSION + git commit -m "chore: sync VERSION to ${VERSION} [skip ci]" + git push origin HEAD:${{ github.event.repository.default_branch }} diff --git a/.goreleaser.simple.yaml b/.goreleaser.simple.yaml index 2155ed9d..14f67fd1 100644 --- 
a/.goreleaser.simple.yaml +++ b/.goreleaser.simple.yaml @@ -47,6 +47,8 @@ dockers: - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest" dockerfile: Dockerfile.goreleaser use: buildx + extra_files: + - deploy/docker-entrypoint.sh build_flag_templates: - "--platform=linux/amd64" - "--label=org.opencontainers.image.version={{ .Version }}" diff --git a/.goreleaser.yaml b/.goreleaser.yaml index da2f9aa5..41f9a555 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -63,6 +63,8 @@ dockers: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" dockerfile: Dockerfile.goreleaser use: buildx + extra_files: + - deploy/docker-entrypoint.sh build_flag_templates: - "--platform=linux/amd64" - "--label=org.opencontainers.image.version={{ .Version }}" @@ -76,6 +78,8 @@ dockers: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" dockerfile: Dockerfile.goreleaser use: buildx + extra_files: + - deploy/docker-entrypoint.sh build_flag_templates: - "--platform=linux/arm64" - "--label=org.opencontainers.image.version={{ .Version }}" @@ -89,6 +93,8 @@ dockers: - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" dockerfile: Dockerfile.goreleaser use: buildx + extra_files: + - deploy/docker-entrypoint.sh build_flag_templates: - "--platform=linux/amd64" - "--label=org.opencontainers.image.version={{ .Version }}" @@ -102,6 +108,8 @@ dockers: - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" dockerfile: Dockerfile.goreleaser use: buildx + extra_files: + - deploy/docker-entrypoint.sh build_flag_templates: - "--platform=linux/arm64" - "--label=org.opencontainers.image.version={{ .Version }}" diff --git a/Dockerfile b/Dockerfile index 8fd48cc2..a16eb958 100644 --- a/Dockerfile +++ b/Dockerfile @@ -92,6 +92,7 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ + su-exec \ libpq \ zstd-libs \ lz4-libs \ @@ -120,8 +121,9 @@ COPY 
--from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/ # Create data directory RUN mkdir -p /app/data && chown sub2api:sub2api /app/data -# Switch to non-root user -USER sub2api +# Copy entrypoint script (fixes volume permissions then drops to sub2api) +COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh +RUN chmod +x /app/docker-entrypoint.sh # Expose port (can be overridden by SERVER_PORT env var) EXPOSE 8080 @@ -130,5 +132,6 @@ EXPOSE 8080 HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 -# Run the application -ENTRYPOINT ["/app/sub2api"] +# Run the application (entrypoint fixes /app/data ownership then execs as sub2api) +ENTRYPOINT ["/app/docker-entrypoint.sh"] +CMD ["/app/sub2api"] diff --git a/Dockerfile.goreleaser b/Dockerfile.goreleaser index 419994b9..f251d154 100644 --- a/Dockerfile.goreleaser +++ b/Dockerfile.goreleaser @@ -21,6 +21,7 @@ RUN apk add --no-cache \ ca-certificates \ tzdata \ curl \ + su-exec \ libpq \ zstd-libs \ lz4-libs \ @@ -47,11 +48,15 @@ COPY sub2api /app/sub2api # Create data directory RUN mkdir -p /app/data && chown -R sub2api:sub2api /app -USER sub2api +# Copy entrypoint script (fixes volume permissions then drops to sub2api) +COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh +RUN chmod +x /app/docker-entrypoint.sh EXPOSE 8080 HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 -ENTRYPOINT ["/app/sub2api"] +# Run the application (entrypoint fixes /app/data ownership then execs as sub2api) +ENTRYPOINT ["/app/docker-entrypoint.sh"] +CMD ["/app/sub2api"] diff --git a/README.md b/README.md index 4a7bde8e..99753e45 100644 --- a/README.md +++ b/README.md @@ -8,27 +8,31 @@ [![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) 
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/) +Wei-Shaw%2Fsub2api | Trendshift + **AI API Gateway Platform for Subscription Quota Distribution** -English | [中文](README_CN.md) +English | [中文](README_CN.md) | [日本語](README_JA.md) +> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.** + --- ## Demo -Try Sub2API online: **https://demo.sub2api.org/** +Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)** Demo credentials (shared demo environment; **not** created automatically for self-hosted installs): | Email | Password | |-------|----------| -| admin@sub2api.com | admin123 | +| admin@sub2api.org | admin123 | ## Overview -Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding. +Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding. ## Features @@ -41,6 +45,19 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - **Admin Dashboard** - Web interface for monitoring and management - **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard +## Don't Want to Self-Host? + + + + + + + + + + +
pinccPinCC is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.
PackyCodeThanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "sub2api" promo code during first recharge to get 10% off.
+ ## Ecosystem Community projects that extend or integrate with Sub2API: @@ -61,10 +78,15 @@ Community projects that extend or integrate with Sub2API: --- -## Documentation +## Nginx Reverse Proxy Note -- Dependency Security: `docs/dependency-security.md` -- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md` +When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration: + +```nginx +underscores_in_headers on; +``` + +Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups. --- @@ -160,10 +182,10 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # Start services -docker-compose up -d +docker compose up -d # View logs -docker-compose logs -f sub2api +docker compose logs -f sub2api ``` **What the script does:** @@ -227,16 +249,16 @@ mkdir -p data postgres_data redis_data # 5. Start all services # Option A: Local directory version (recommended - easy migration) -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d # Option B: Named volumes version (simple setup) -docker-compose up -d +docker compose up -d # 6. Check status -docker-compose -f docker-compose.local.yml ps +docker compose -f docker-compose.local.yml ps # 7. View logs -docker-compose -f docker-compose.local.yml logs -f sub2api +docker compose -f docker-compose.local.yml logs -f sub2api ``` #### Deployment Versions @@ -254,15 +276,15 @@ Open `http://YOUR_SERVER_IP:8080` in your browser. 
If admin password was auto-generated, find it in logs: ```bash -docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password" +docker compose -f docker-compose.local.yml logs sub2api | grep "admin password" ``` #### Upgrade ```bash # Pull latest image and recreate container -docker-compose -f docker-compose.local.yml pull -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml pull +docker compose -f docker-compose.local.yml up -d ``` #### Easy Migration (Local Directory Version) @@ -271,7 +293,7 @@ When using `docker-compose.local.yml`, migrate to a new server easily: ```bash # On source server -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down cd .. tar czf sub2api-complete.tar.gz sub2api-deploy/ @@ -281,23 +303,23 @@ scp sub2api-complete.tar.gz user@new-server:/path/ # On new server tar xzf sub2api-complete.tar.gz cd sub2api-deploy/ -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d ``` #### Useful Commands ```bash # Stop all services -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down # Restart -docker-compose -f docker-compose.local.yml restart +docker compose -f docker-compose.local.yml restart # View all logs -docker-compose -f docker-compose.local.yml logs -f +docker compose -f docker-compose.local.yml logs -f # Remove all data (caution!) 
-docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down rm -rf data/ postgres_data/ redis_data/ ``` diff --git a/README_CN.md b/README_CN.md index eee89b07..8b6feaba 100644 --- a/README_CN.md +++ b/README_CN.md @@ -8,27 +8,30 @@ [![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) [![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/) +Wei-Shaw%2Fsub2api | Trendshift + **AI API 网关平台 - 订阅配额分发管理** -[English](README.md) | 中文 +[English](README.md) | 中文 | [日本語](README_JA.md) +> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。** --- ## 在线体验 -体验地址:**https://v2.pincc.ai/** +体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)** 演示账号(共享演示环境;自建部署不会自动创建该账号): | 邮箱 | 密码 | |------|------| -| admin@sub2api.com | admin123 | +| admin@sub2api.org | admin123 | ## 项目概述 -Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。 +Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。 ## 核心功能 @@ -41,6 +44,19 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - **管理后台** - Web 界面进行监控和管理 - **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能 +## 不想自建?试试官方中转 + + + + + + + + + + +
pinccPinCC 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。
PackyCode感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用此链接注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!
+ ## 生态项目 围绕 Sub2API 的社区扩展与集成项目: @@ -61,17 +77,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( --- -## 文档 +## Nginx 反向代理注意事项 -- 依赖安全:`docs/dependency-security.md` +通过 Nginx 反向代理 Sub2API(或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加: + +```nginx +underscores_in_headers on; +``` + +Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。 --- -## OpenAI Responses 兼容注意事项 - -- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 -- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 - ## 部署方式 ### 方式一:脚本安装(推荐) @@ -164,10 +181,10 @@ mkdir -p sub2api-deploy && cd sub2api-deploy curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash # 启动服务 -docker-compose up -d +docker compose up -d # 查看日志 -docker-compose logs -f sub2api +docker compose logs -f sub2api ``` **脚本功能:** @@ -231,16 +248,16 @@ mkdir -p data postgres_data redis_data # 5. 启动所有服务 # 选项 A:本地目录版(推荐 - 易于迁移) -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d # 选项 B:命名卷版(简单设置) -docker-compose up -d +docker compose up -d # 6. 查看状态 -docker-compose -f docker-compose.local.yml ps +docker compose -f docker-compose.local.yml ps # 7. 
查看日志 -docker-compose -f docker-compose.local.yml logs -f sub2api +docker compose -f docker-compose.local.yml logs -f sub2api ``` #### 部署版本对比 @@ -270,15 +287,15 @@ docker-compose -f docker-compose.local.yml logs -f sub2api 如果管理员密码是自动生成的,在日志中查找: ```bash -docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password" +docker compose -f docker-compose.local.yml logs sub2api | grep "admin password" ``` #### 升级 ```bash # 拉取最新镜像并重建容器 -docker-compose -f docker-compose.local.yml pull -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml pull +docker compose -f docker-compose.local.yml up -d ``` #### 轻松迁移(本地目录版) @@ -287,7 +304,7 @@ docker-compose -f docker-compose.local.yml up -d ```bash # 源服务器 -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down cd .. tar czf sub2api-complete.tar.gz sub2api-deploy/ @@ -297,23 +314,23 @@ scp sub2api-complete.tar.gz user@new-server:/path/ # 新服务器 tar xzf sub2api-complete.tar.gz cd sub2api-deploy/ -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d ``` #### 常用命令 ```bash # 停止所有服务 -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down # 重启 -docker-compose -f docker-compose.local.yml restart +docker compose -f docker-compose.local.yml restart # 查看所有日志 -docker-compose -f docker-compose.local.yml logs -f +docker compose -f docker-compose.local.yml logs -f # 删除所有数据(谨慎!) -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down rm -rf data/ postgres_data/ redis_data/ ``` diff --git a/README_JA.md b/README_JA.md new file mode 100644 index 00000000..1266bd84 --- /dev/null +++ b/README_JA.md @@ -0,0 +1,589 @@ +# Sub2API + +
+ +[![Go](https://img.shields.io/badge/Go-1.25.7-00ADD8.svg)](https://golang.org/) +[![Vue](https://img.shields.io/badge/Vue-3.4+-4FC08D.svg)](https://vuejs.org/) +[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-15+-336791.svg)](https://www.postgresql.org/) +[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/) +[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/) + +Wei-Shaw%2Fsub2api | Trendshift + +**サブスクリプションクォータ配分のための AI API ゲートウェイプラットフォーム** + +[English](README.md) | [中文](README_CN.md) | 日本語 + +
+ +> **Sub2API が公式に使用しているドメインは `sub2api.org` と `pincc.ai` のみです。Sub2API の名称を使用している他のウェブサイトは、サードパーティによるデプロイやサービスであり、本プロジェクトとは一切関係がありません。ご利用の際はご自身で確認・判断をお願いします。** + +--- + +## デモ + +Sub2API をオンラインでお試しください: **[https://demo.sub2api.org/](https://demo.sub2api.org/)** + +デモ用認証情報(共有デモ環境です。セルフホスト環境では**自動作成されません**): + +| メールアドレス | パスワード | +|-------|----------| +| admin@sub2api.org | admin123 | + +## 概要 + +Sub2API は、AI 製品のサブスクリプションから API クォータを配分・管理するために設計された AI API ゲートウェイプラットフォームです。ユーザーはプラットフォームが生成した API キーを通じて上流の AI サービスにアクセスでき、プラットフォームは認証、課金、負荷分散、リクエスト転送を処理します。 + +## 機能 + +- **マルチアカウント管理** - 複数の上流アカウントタイプ(OAuth、APIキー)をサポート +- **APIキー配布** - ユーザー向けの APIキーの生成と管理 +- **精密な課金** - トークンレベルの使用量追跡とコスト計算 +- **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択 +- **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限 +- **レート制限** - 設定可能なリクエスト数およびトークンレート制限 +- **管理ダッシュボード** - 監視・管理のための Web インターフェース +- **外部システム連携** - 外部システム(決済、チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能 + +## セルフホストが不要な方へ + + + + + + + + + + +
pinccPinCC は Sub2API 上に構築された公式リレーサービスで、Claude Code、Codex、Gemini などの人気モデルへの安定したアクセスを提供します。デプロイやメンテナンスは不要で、すぐにご利用いただけます。
PackyCodePackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:このリンクで登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。
+ +## エコシステム + +Sub2API を拡張・統合するコミュニティプロジェクト: + +| プロジェクト | 説明 | 機能 | +|---------|-------------|----------| +| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | セルフサービス決済システム | セルフサービスによるチャージおよびサブスクリプション購入。YiPay プロトコル、WeChat Pay、Alipay、Stripe 対応。iframe での埋め込み可能 | +| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリ(iOS/Android/Web)。Expo + React Native で構築 | + +## 技術スタック + +| コンポーネント | 技術 | +|-----------|------------| +| バックエンド | Go 1.25.7, Gin, Ent | +| フロントエンド | Vue 3.4+, Vite 5+, TailwindCSS | +| データベース | PostgreSQL 15+ | +| キャッシュ/キュー | Redis 7+ | + +--- + +## Nginx リバースプロキシに関する注意 + +Sub2API(または CRS)を Nginx でリバースプロキシし、Codex CLI と組み合わせて使用する場合、Nginx の `http` ブロックに以下の設定を追加してください: + +```nginx +underscores_in_headers on; +``` + +Nginx はデフォルトでアンダースコアを含むヘッダー(例: `session_id`)を破棄するため、マルチアカウント構成でのスティッキーセッションルーティングに支障をきたします。 + +--- + +## デプロイ + +### 方法1: スクリプトによるインストール(推奨) + +GitHub Releases からビルド済みバイナリをダウンロードするワンクリックインストールスクリプトです。 + +#### 前提条件 + +- Linux サーバー(amd64 または arm64) +- PostgreSQL 15+(インストール済みかつ稼働中) +- Redis 7+(インストール済みかつ稼働中) +- root 権限 + +#### インストール手順 + +```bash +curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash +``` + +スクリプトは以下を実行します: +1. システムアーキテクチャの検出 +2. 最新リリースのダウンロード +3. バイナリを `/opt/sub2api` にインストール +4. systemd サービスの作成 +5. システムユーザーと権限の設定 + +#### インストール後の作業 + +```bash +# 1. サービスを起動 +sudo systemctl start sub2api + +# 2. 起動時の自動起動を有効化 +sudo systemctl enable sub2api + +# 3. 
ブラウザでセットアップウィザードを開く +# http://YOUR_SERVER_IP:8080 +``` + +セットアップウィザードでは以下の設定を行います: +- データベース設定 +- Redis 設定 +- 管理者アカウントの作成 + +#### アップグレード + +**管理ダッシュボード**の左上にある**アップデートを確認**ボタンをクリックすることで、ダッシュボードから直接アップグレードできます。 + +Web インターフェースでは以下が可能です: +- 新しいバージョンの自動確認 +- ワンクリックでのアップデートのダウンロードと適用 +- 必要に応じたロールバック + +#### よく使うコマンド + +```bash +# ステータスを確認 +sudo systemctl status sub2api + +# ログを表示 +sudo journalctl -u sub2api -f + +# サービスを再起動 +sudo systemctl restart sub2api + +# アンインストール +curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash -s -- uninstall -y +``` + +--- + +### 方法2: Docker Compose(推奨) + +PostgreSQL と Redis のコンテナを含む Docker Compose でデプロイします。 + +#### 前提条件 + +- Docker 20.10+ +- Docker Compose v2+ + +#### クイックスタート(ワンクリックデプロイ) + +自動デプロイスクリプトを使用して簡単にセットアップできます: + +```bash +# デプロイ用ディレクトリを作成 +mkdir -p sub2api-deploy && cd sub2api-deploy + +# デプロイ準備スクリプトをダウンロードして実行 +curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/docker-deploy.sh | bash + +# サービスを起動 +docker compose up -d + +# ログを表示 +docker compose logs -f sub2api +``` + +**スクリプトの動作内容:** +- `docker-compose.local.yml`(`docker-compose.yml` として保存)と `.env.example` をダウンロード +- セキュアな認証情報(JWT_SECRET、TOTP_ENCRYPTION_KEY、POSTGRES_PASSWORD)を自動生成 +- 自動生成されたシークレットで `.env` ファイルを作成 +- データディレクトリを作成(バックアップ・移行が容易なローカルディレクトリを使用) +- 生成された認証情報を参照用に表示 + +#### 手動デプロイ + +手動でセットアップする場合: + +```bash +# 1. リポジトリをクローン +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api/deploy + +# 2. 環境設定ファイルをコピー +cp .env.example .env + +# 3. 
設定を編集(セキュアなパスワードを生成) +nano .env +``` + +**`.env` の必須設定:** + +```bash +# PostgreSQL パスワード(必須) +POSTGRES_PASSWORD=your_secure_password_here + +# JWT シークレット(推奨 - 再起動後もユーザーのログイン状態を保持) +JWT_SECRET=your_jwt_secret_here + +# TOTP 暗号化キー(推奨 - 再起動後も二要素認証を維持) +TOTP_ENCRYPTION_KEY=your_totp_key_here + +# オプション: 管理者アカウント +ADMIN_EMAIL=admin@example.com +ADMIN_PASSWORD=your_admin_password + +# オプション: カスタムポート +SERVER_PORT=8080 +``` + +**セキュアなシークレットの生成方法:** +```bash +# JWT_SECRET を生成 +openssl rand -hex 32 + +# TOTP_ENCRYPTION_KEY を生成 +openssl rand -hex 32 + +# POSTGRES_PASSWORD を生成 +openssl rand -hex 32 +``` + +```bash +# 4. データディレクトリを作成(ローカルバージョンの場合) +mkdir -p data postgres_data redis_data + +# 5. すべてのサービスを起動 +# オプション A: ローカルディレクトリバージョン(推奨 - 移行が容易) +docker compose -f docker-compose.local.yml up -d + +# オプション B: 名前付きボリュームバージョン(シンプルなセットアップ) +docker compose up -d + +# 6. ステータスを確認 +docker compose -f docker-compose.local.yml ps + +# 7. ログを表示 +docker compose -f docker-compose.local.yml logs -f sub2api +``` + +#### デプロイバージョン + +| バージョン | データストレージ | 移行 | 推奨用途 | +|---------|-------------|-----------|----------| +| **docker-compose.local.yml** | ローカルディレクトリ | ✅ 容易(ディレクトリ全体を tar) | 本番環境、頻繁なバックアップ | +| **docker-compose.yml** | 名前付きボリューム | ⚠️ docker コマンドが必要 | シンプルなセットアップ | + +**推奨:** データ管理が容易な `docker-compose.local.yml`(スクリプトによるデプロイ)を使用してください。 + +#### アクセス + +ブラウザで `http://YOUR_SERVER_IP:8080` を開いてください。 + +管理者パスワードが自動生成された場合は、ログで確認できます: +```bash +docker compose -f docker-compose.local.yml logs sub2api | grep "admin password" +``` + +#### アップグレード + +```bash +# 最新イメージをプルしてコンテナを再作成 +docker compose -f docker-compose.local.yml pull +docker compose -f docker-compose.local.yml up -d +``` + +#### 簡単な移行(ローカルディレクトリバージョン) + +`docker-compose.local.yml` を使用している場合、新しいサーバーへの移行が簡単です: + +```bash +# 移行元サーバーにて +docker compose -f docker-compose.local.yml down +cd .. 
+tar czf sub2api-complete.tar.gz sub2api-deploy/ + +# 新しいサーバーに転送 +scp sub2api-complete.tar.gz user@new-server:/path/ + +# 移行先サーバーにて +tar xzf sub2api-complete.tar.gz +cd sub2api-deploy/ +docker compose -f docker-compose.local.yml up -d +``` + +#### よく使うコマンド + +```bash +# すべてのサービスを停止 +docker compose -f docker-compose.local.yml down + +# 再起動 +docker compose -f docker-compose.local.yml restart + +# すべてのログを表示 +docker compose -f docker-compose.local.yml logs -f + +# すべてのデータを削除(注意!) +docker compose -f docker-compose.local.yml down +rm -rf data/ postgres_data/ redis_data/ +``` + +--- + +### 方法3: ソースからビルド + +開発やカスタマイズのためにソースコードからビルドして実行します。 + +#### 前提条件 + +- Go 1.21+ +- Node.js 18+ +- PostgreSQL 15+ +- Redis 7+ + +#### ビルド手順 + +```bash +# 1. リポジトリをクローン +git clone https://github.com/Wei-Shaw/sub2api.git +cd sub2api + +# 2. pnpm をインストール(未インストールの場合) +npm install -g pnpm + +# 3. フロントエンドをビルド +cd frontend +pnpm install +pnpm run build +# 出力先: ../backend/internal/web/dist/ + +# 4. フロントエンドを組み込んだバックエンドをビルド +cd ../backend +go build -tags embed -o sub2api ./cmd/server + +# 5. 設定ファイルを作成 +cp ../deploy/config.example.yaml ./config.yaml + +# 6. 
設定を編集 +nano config.yaml +``` + +> **注意:** `-tags embed` フラグはフロントエンドをバイナリに組み込みます。このフラグがない場合、バイナリはフロントエンド UI を提供しません。 + +**`config.yaml` の主要設定:** + +```yaml +server: + host: "0.0.0.0" + port: 8080 + mode: "release" + +database: + host: "localhost" + port: 5432 + user: "postgres" + password: "your_password" + dbname: "sub2api" + +redis: + host: "localhost" + port: 6379 + password: "" + +jwt: + secret: "change-this-to-a-secure-random-string" + expire_hour: 24 + +default: + user_concurrency: 5 + user_balance: 0 + api_key_prefix: "sk-" + rate_multiplier: 1.0 +``` + +### Sora ステータス(一時的に利用不可) + +> ⚠️ Sora 関連の機能は、上流統合およびメディア配信の技術的問題により一時的に利用できません。 +> 現時点では本番環境で Sora に依存しないでください。 +> 既存の `gateway.sora_*` 設定キーは予約されていますが、これらの問題が解決されるまで有効にならない場合があります。 + +`config.yaml` では追加のセキュリティ関連オプションも利用できます: + +- `cors.allowed_origins` - CORS 許可リスト +- `security.url_allowlist` - 上流/価格/CRS ホストの許可リスト +- `security.url_allowlist.enabled` - URL バリデーションの無効化(注意して使用) +- `security.url_allowlist.allow_insecure_http` - バリデーション無効時に HTTP URL を許可 +- `security.url_allowlist.allow_private_hosts` - プライベート/ローカル IP アドレスを許可 +- `security.response_headers.enabled` - 設定可能なレスポンスヘッダーフィルタリングを有効化(無効時はデフォルトの許可リストを使用) +- `security.csp` - Content-Security-Policy ヘッダーの制御 +- `billing.circuit_breaker` - 課金エラー時にフェイルクローズ +- `server.trusted_proxies` - X-Forwarded-For パースの有効化 +- `turnstile.required` - リリースモードでの Turnstile 必須化 + +**⚠️ セキュリティ警告: HTTP URL 設定** + +`security.url_allowlist.enabled=false` の場合、システムはデフォルトで最小限の URL バリデーションを行い、**HTTP URL を拒否**して HTTPS のみを許可します。HTTP URL を許可するには(開発環境や内部テスト用など)、以下を明示的に設定する必要があります: + +```yaml +security: + url_allowlist: + enabled: false # 許可リストチェックを無効化 + allow_insecure_http: true # HTTP URL を許可(⚠️ セキュリティリスクあり) +``` + +**または環境変数で設定:** + +```bash +SECURITY_URL_ALLOWLIST_ENABLED=false +SECURITY_URL_ALLOWLIST_ALLOW_INSECURE_HTTP=true +``` + +**HTTP を許可するリスク:** +- API キーとデータが**平文**で送信される(傍受の危険性) +- **中間者攻撃(MITM)**を受けやすい +- **本番環境には不適切** + +**HTTP を使用すべき場面:** +- ✅ ローカルサーバーでの開発・テスト(http://localhost) +- 
✅ 信頼できるエンドポイントを持つ内部ネットワーク +- ✅ HTTPS 取得前のアカウント接続テスト +- ❌ 本番環境(HTTPS のみを使用) + +**この設定なしで表示されるエラー例:** +``` +Invalid base URL: invalid url scheme: http +``` + +URL バリデーションまたはレスポンスヘッダーフィルタリングを無効にする場合は、ネットワーク層を強化してください: +- 上流ドメイン/IP のエグレス許可リストを適用 +- プライベート/ループバック/リンクローカル範囲をブロック +- TLS のみのアウトバウンドトラフィックを強制 +- プロキシで機密性の高い上流レスポンスヘッダーを除去 + +```bash +# 6. アプリケーションを実行 +./sub2api +``` + +#### 開発モード + +```bash +# バックエンド(ホットリロード付き) +cd backend +go run ./cmd/server + +# フロントエンド(ホットリロード付き) +cd frontend +pnpm run dev +``` + +#### コード生成 + +`backend/ent/schema` を編集した場合、Ent + Wire を再生成してください: + +```bash +cd backend +go generate ./ent +go generate ./cmd/server +``` + +--- + +## シンプルモード + +シンプルモードは、フル SaaS 機能を必要とせず、素早くアクセスしたい個人開発者や社内チーム向けに設計されています。 + +- 有効化: 環境変数 `RUN_MODE=simple` を設定 +- 違い: SaaS 関連機能を非表示にし、課金プロセスをスキップ +- セキュリティに関する注意: 本番環境では `SIMPLE_MODE_CONFIRM=true` も設定する必要があります + +--- + +## Antigravity サポート + +Sub2API は [Antigravity](https://antigravity.so/) アカウントをサポートしています。認証後、Claude および Gemini モデル用の専用エンドポイントが利用可能になります。 + +### 専用エンドポイント + +| エンドポイント | モデル | +|----------|-------| +| `/antigravity/v1/messages` | Claude モデル | +| `/antigravity/v1beta/` | Gemini モデル | + +### Claude Code の設定 + +```bash +export ANTHROPIC_BASE_URL="http://localhost:8080/antigravity" +export ANTHROPIC_AUTH_TOKEN="sk-xxx" +``` + +### ハイブリッドスケジューリングモード + +Antigravity アカウントはオプションの**ハイブリッドスケジューリング**をサポートしています。有効にすると、汎用エンドポイント `/v1/messages` および `/v1beta/` も Antigravity アカウントにリクエストをルーティングします。 + +> **⚠️ 警告**: Anthropic Claude と Antigravity Claude は**同じ会話コンテキスト内で混在させることはできません**。グループを使用して適切に分離してください。 + +### 既知の問題 + +Claude Code では、Plan Mode を自動的に終了できません。(通常、ネイティブの Claude API を使用する場合、計画が完了すると Claude Code はユーザーに計画を承認または拒否するオプションをポップアップ表示します。) + +**回避策**: `Shift + Tab` を押して手動で Plan Mode を終了し、計画を承認または拒否するためのレスポンスを入力してください。 + +--- + +## プロジェクト構成 + +``` +sub2api/ +├── backend/ # Go バックエンドサービス +│ ├── cmd/server/ # アプリケーションエントリ +│ ├── internal/ # 内部モジュール +│ │ ├── config/ # 設定 +│ │ ├── model/ # データモデル +│ │ ├── service/ # 
ビジネスロジック +│ │ ├── handler/ # HTTP ハンドラー +│ │ └── gateway/ # API ゲートウェイコア +│ └── resources/ # 静的リソース +│ +├── frontend/ # Vue 3 フロントエンド +│ └── src/ +│ ├── api/ # API 呼び出し +│ ├── stores/ # 状態管理 +│ ├── views/ # ページコンポーネント +│ └── components/ # 再利用可能なコンポーネント +│ +└── deploy/ # デプロイファイル + ├── docker-compose.yml # Docker Compose 設定 + ├── .env.example # Docker Compose 用環境変数 + ├── config.example.yaml # バイナリデプロイ用フル設定ファイル + └── install.sh # ワンクリックインストールスクリプト +``` + +## 免責事項 + +> **本プロジェクトをご利用の前に、以下をよくお読みください:** +> +> :rotating_light: **利用規約違反のリスク**: 本プロジェクトの使用は Anthropic の利用規約に違反する可能性があります。使用前に Anthropic のユーザー契約をよくお読みください。本プロジェクトの使用に起因するすべてのリスクは、ユーザー自身が負うものとします。 +> +> :book: **免責事項**: 本プロジェクトは技術的な学習および研究目的のみで提供されています。作者は、本プロジェクトの使用によるアカウント停止、サービス中断、その他の損失について一切の責任を負いません。 + +--- + +## スター履歴 + + + + + + Star History Chart + + + +--- + +## ライセンス + +MIT License + +--- + +
+ +**このプロジェクトが役に立ったら、ぜひスターをお願いします!** + +
diff --git a/assets/partners/logos/packycode.png b/assets/partners/logos/packycode.png new file mode 100644 index 00000000..4fc7eecc Binary files /dev/null and b/assets/partners/logos/packycode.png differ diff --git a/assets/partners/logos/pincc-logo.png b/assets/partners/logos/pincc-logo.png new file mode 100644 index 00000000..081b6c84 Binary files /dev/null and b/assets/partners/logos/pincc-logo.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 32844913..9e3db2aa 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.88 \ No newline at end of file +0.1.106 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f632bff3..ce898a4a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -110,11 +110,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) - groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() driveClient := repository.NewGeminiDriveClient() @@ -132,17 +132,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() 
identityCache := repository.NewIdentityCache(redisClient) - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) - antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) + antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) + internal500CounterCache := repository.NewInternal500CounterCache(redisClient) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) + tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) + tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) + tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) + accountUsageService := 
service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() @@ -169,7 +175,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, 
billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) @@ -201,12 +207,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) + tlsFingerprintProfileHandler := admin.NewTLSFingerprintProfileHandler(tlsFingerprintProfileService) adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) scheduledTestPlanRepository := 
repository.NewScheduledTestPlanRepository(db) scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) diff --git a/backend/ent/client.go b/backend/ent/client.go index 7ebbaa32..4129d6c5 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -29,6 +29,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" 
"github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -73,6 +74,8 @@ type Client struct { SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders. + TLSFingerprintProfile *TLSFingerprintProfileClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. UsageCleanupTask *UsageCleanupTaskClient // UsageLog is the client for interacting with the UsageLog builders. @@ -112,6 +115,7 @@ func (c *Client) init() { c.RedeemCode = NewRedeemCodeClient(c.config) c.SecuritySecret = NewSecuritySecretClient(c.config) c.Setting = NewSettingClient(c.config) + c.TLSFingerprintProfile = NewTLSFingerprintProfileClient(c.config) c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config) c.UsageLog = NewUsageLogClient(c.config) c.User = NewUserClient(c.config) @@ -225,6 +229,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { RedeemCode: NewRedeemCodeClient(cfg), SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), @@ -265,6 +270,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) RedeemCode: NewRedeemCodeClient(cfg), SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), @@ -304,8 +310,9 @@ func (c *Client) Use(hooks ...Hook) { c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, 
- c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Use(hooks...) } @@ -318,8 +325,9 @@ func (c *Client) Intercept(interceptors ...Interceptor) { c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Intercept(interceptors...) } @@ -356,6 +364,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.SecuritySecret.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) + case *TLSFingerprintProfileMutation: + return c.TLSFingerprintProfile.mutate(ctx, m) case *UsageCleanupTaskMutation: return c.UsageCleanupTask.mutate(ctx, m) case *UsageLogMutation: @@ -2612,6 +2622,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value, } } +// TLSFingerprintProfileClient is a client for the TLSFingerprintProfile schema. +type TLSFingerprintProfileClient struct { + config +} + +// NewTLSFingerprintProfileClient returns a client for the TLSFingerprintProfile from the given config. +func NewTLSFingerprintProfileClient(c config) *TLSFingerprintProfileClient { + return &TLSFingerprintProfileClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `tlsfingerprintprofile.Hooks(f(g(h())))`. 
+func (c *TLSFingerprintProfileClient) Use(hooks ...Hook) { + c.hooks.TLSFingerprintProfile = append(c.hooks.TLSFingerprintProfile, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `tlsfingerprintprofile.Intercept(f(g(h())))`. +func (c *TLSFingerprintProfileClient) Intercept(interceptors ...Interceptor) { + c.inters.TLSFingerprintProfile = append(c.inters.TLSFingerprintProfile, interceptors...) +} + +// Create returns a builder for creating a TLSFingerprintProfile entity. +func (c *TLSFingerprintProfileClient) Create() *TLSFingerprintProfileCreate { + mutation := newTLSFingerprintProfileMutation(c.config, OpCreate) + return &TLSFingerprintProfileCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of TLSFingerprintProfile entities. +func (c *TLSFingerprintProfileClient) CreateBulk(builders ...*TLSFingerprintProfileCreate) *TLSFingerprintProfileCreateBulk { + return &TLSFingerprintProfileCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *TLSFingerprintProfileClient) MapCreateBulk(slice any, setFunc func(*TLSFingerprintProfileCreate, int)) *TLSFingerprintProfileCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &TLSFingerprintProfileCreateBulk{err: fmt.Errorf("calling to TLSFingerprintProfileClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*TLSFingerprintProfileCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &TLSFingerprintProfileCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for TLSFingerprintProfile. 
+func (c *TLSFingerprintProfileClient) Update() *TLSFingerprintProfileUpdate { + mutation := newTLSFingerprintProfileMutation(c.config, OpUpdate) + return &TLSFingerprintProfileUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *TLSFingerprintProfileClient) UpdateOne(_m *TLSFingerprintProfile) *TLSFingerprintProfileUpdateOne { + mutation := newTLSFingerprintProfileMutation(c.config, OpUpdateOne, withTLSFingerprintProfile(_m)) + return &TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *TLSFingerprintProfileClient) UpdateOneID(id int64) *TLSFingerprintProfileUpdateOne { + mutation := newTLSFingerprintProfileMutation(c.config, OpUpdateOne, withTLSFingerprintProfileID(id)) + return &TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for TLSFingerprintProfile. +func (c *TLSFingerprintProfileClient) Delete() *TLSFingerprintProfileDelete { + mutation := newTLSFingerprintProfileMutation(c.config, OpDelete) + return &TLSFingerprintProfileDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *TLSFingerprintProfileClient) DeleteOne(_m *TLSFingerprintProfile) *TLSFingerprintProfileDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *TLSFingerprintProfileClient) DeleteOneID(id int64) *TLSFingerprintProfileDeleteOne { + builder := c.Delete().Where(tlsfingerprintprofile.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &TLSFingerprintProfileDeleteOne{builder} +} + +// Query returns a query builder for TLSFingerprintProfile. 
+func (c *TLSFingerprintProfileClient) Query() *TLSFingerprintProfileQuery { + return &TLSFingerprintProfileQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeTLSFingerprintProfile}, + inters: c.Interceptors(), + } +} + +// Get returns a TLSFingerprintProfile entity by its id. +func (c *TLSFingerprintProfileClient) Get(ctx context.Context, id int64) (*TLSFingerprintProfile, error) { + return c.Query().Where(tlsfingerprintprofile.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *TLSFingerprintProfileClient) GetX(ctx context.Context, id int64) *TLSFingerprintProfile { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *TLSFingerprintProfileClient) Hooks() []Hook { + return c.hooks.TLSFingerprintProfile +} + +// Interceptors returns the client interceptors. +func (c *TLSFingerprintProfileClient) Interceptors() []Interceptor { + return c.inters.TLSFingerprintProfile +} + +func (c *TLSFingerprintProfileClient) mutate(ctx context.Context, m *TLSFingerprintProfileMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&TLSFingerprintProfileCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&TLSFingerprintProfileUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&TLSFingerprintProfileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&TLSFingerprintProfileDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown TLSFingerprintProfile mutation op: %q", m.Op()) + } +} + // UsageCleanupTaskClient is a client for the UsageCleanupTask schema. 
type UsageCleanupTaskClient struct { config @@ -3889,16 +4032,16 @@ type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, - Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Hook + Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, - Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Interceptor + Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 5197e4d8..bdeaed8a 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -107,6 +108,7 @@ func checkColumn(t, c string) error { redeemcode.Table: redeemcode.ValidColumn, securitysecret.Table: securitysecret.ValidColumn, setting.Table: setting.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, usagecleanuptask.Table: usagecleanuptask.ValidColumn, usagelog.Table: usagelog.ValidColumn, user.Table: 
user.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index 3db54a64..fc691a9b 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -80,6 +80,10 @@ type Group struct { SortOrder int `json:"sort_order,omitempty"` // 是否允许 /v1/messages 调度到此 OpenAI 分组 AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` + // 仅允许非 apikey 类型账号关联到此分组 + RequireOauthOnly bool `json:"require_oauth_only,omitempty"` + // 调度时仅允许 privacy 已成功设置的账号 + RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -190,7 +194,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) @@ -425,6 +429,18 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.AllowMessagesDispatch = value.Bool } + case group.FieldRequireOauthOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_oauth_only", values[i]) + } else if value.Valid { + _m.RequireOauthOnly = 
value.Bool + } + case group.FieldRequirePrivacySet: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_privacy_set", values[i]) + } else if value.Valid { + _m.RequirePrivacySet = value.Bool + } case group.FieldDefaultMappedModel: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i]) @@ -628,6 +644,12 @@ func (_m *Group) String() string { builder.WriteString("allow_messages_dispatch=") builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch)) builder.WriteString(", ") + builder.WriteString("require_oauth_only=") + builder.WriteString(fmt.Sprintf("%v", _m.RequireOauthOnly)) + builder.WriteString(", ") + builder.WriteString("require_privacy_set=") + builder.WriteString(fmt.Sprintf("%v", _m.RequirePrivacySet)) + builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) builder.WriteByte(')') diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 2612b6cf..35222127 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -77,6 +77,10 @@ const ( FieldSortOrder = "sort_order" // FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database. FieldAllowMessagesDispatch = "allow_messages_dispatch" + // FieldRequireOauthOnly holds the string denoting the require_oauth_only field in the database. + FieldRequireOauthOnly = "require_oauth_only" + // FieldRequirePrivacySet holds the string denoting the require_privacy_set field in the database. + FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. 
@@ -185,6 +189,8 @@ var Columns = []string{ FieldSupportedModelScopes, FieldSortOrder, FieldAllowMessagesDispatch, + FieldRequireOauthOnly, + FieldRequirePrivacySet, FieldDefaultMappedModel, } @@ -255,6 +261,10 @@ var ( DefaultSortOrder int // DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field. DefaultAllowMessagesDispatch bool + // DefaultRequireOauthOnly holds the default value on creation for the "require_oauth_only" field. + DefaultRequireOauthOnly bool + // DefaultRequirePrivacySet holds the default value on creation for the "require_privacy_set" field. + DefaultRequirePrivacySet bool // DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field. DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -414,6 +424,16 @@ func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc() } +// ByRequireOauthOnly orders the results by the require_oauth_only field. +func ByRequireOauthOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequireOauthOnly, opts...).ToFunc() +} + +// ByRequirePrivacySet orders the results by the require_privacy_set field. +func ByRequirePrivacySet(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequirePrivacySet, opts...).ToFunc() +} + // ByDefaultMappedModel orders the results by the default_mapped_model field. 
func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 5dd8759e..41bd575a 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -200,6 +200,16 @@ func AllowMessagesDispatch(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnly applies equality check predicate on the "require_oauth_only" field. It's identical to RequireOauthOnlyEQ. +func RequireOauthOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySet applies equality check predicate on the "require_privacy_set" field. It's identical to RequirePrivacySetEQ. +func RequirePrivacySet(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ. func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) @@ -1490,6 +1500,26 @@ func AllowMessagesDispatchNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnlyEQ applies the EQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequireOauthOnlyNEQ applies the NEQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySetEQ applies the EQ predicate on the "require_privacy_set" field. 
+func RequirePrivacySetEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + +// RequirePrivacySetNEQ applies the NEQ predicate on the "require_privacy_set" field. +func RequirePrivacySetNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field. func DefaultMappedModelEQ(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 6db5b974..a635dfd9 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -438,6 +438,34 @@ func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate { return _c } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_c *GroupCreate) SetRequireOauthOnly(v bool) *GroupCreate { + _c.mutation.SetRequireOauthOnly(v) + return _c +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequireOauthOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetRequireOauthOnly(*v) + } + return _c +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_c *GroupCreate) SetRequirePrivacySet(v bool) *GroupCreate { + _c.mutation.SetRequirePrivacySet(v) + return _c +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequirePrivacySet(v *bool) *GroupCreate { + if v != nil { + _c.SetRequirePrivacySet(*v) + } + return _c +} + // SetDefaultMappedModel sets the "default_mapped_model" field. 
func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate { _c.mutation.SetDefaultMappedModel(v) @@ -645,6 +673,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultAllowMessagesDispatch _c.mutation.SetAllowMessagesDispatch(v) } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + v := group.DefaultRequireOauthOnly + _c.mutation.SetRequireOauthOnly(v) + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + v := group.DefaultRequirePrivacySet + _c.mutation.SetRequirePrivacySet(v) + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) @@ -722,6 +758,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)} } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + return &ValidationError{Name: "require_oauth_only", err: errors.New(`ent: missing required field "Group.require_oauth_only"`)} + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + return &ValidationError{Name: "require_privacy_set", err: errors.New(`ent: missing required field "Group.require_privacy_set"`)} + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)} } @@ -881,6 +923,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) _node.AllowMessagesDispatch = value } + if value, ok := _c.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + _node.RequireOauthOnly = value + } + if value, ok := _c.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + _node.RequirePrivacySet = value + } if value, ok := 
_c.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value @@ -1587,6 +1637,30 @@ func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert { return u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsert) SetRequireOauthOnly(v bool) *GroupUpsert { + u.Set(group.FieldRequireOauthOnly, v) + return u +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequireOauthOnly() *GroupUpsert { + u.SetExcluded(group.FieldRequireOauthOnly) + return u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsert) SetRequirePrivacySet(v bool) *GroupUpsert { + u.Set(group.FieldRequirePrivacySet, v) + return u +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequirePrivacySet() *GroupUpsert { + u.SetExcluded(group.FieldRequirePrivacySet) + return u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert { u.Set(group.FieldDefaultMappedModel, v) @@ -2281,6 +2355,34 @@ func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertOne) SetRequireOauthOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequireOauthOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. 
+func (u *GroupUpsertOne) SetRequirePrivacySet(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequirePrivacySet() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -3143,6 +3245,34 @@ func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertBulk) SetRequireOauthOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequireOauthOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsertBulk) SetRequirePrivacySet(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequirePrivacySet() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. 
func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index b3698596..a9a4b9da 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -639,6 +639,34 @@ func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate { return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdate) SetRequireOauthOnly(v bool) *GroupUpdate { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequireOauthOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdate) SetRequirePrivacySet(v bool) *GroupUpdate { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequirePrivacySet(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. 
func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate { _u.mutation.SetDefaultMappedModel(v) @@ -1146,6 +1174,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } @@ -2067,6 +2101,34 @@ func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdateOne) SetRequireOauthOnly(v bool) *GroupUpdateOne { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequireOauthOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdateOne) SetRequirePrivacySet(v bool) *GroupUpdateOne { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequirePrivacySet(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. 
func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne { _u.mutation.SetDefaultMappedModel(v) @@ -2604,6 +2666,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 49d7f3c5..f6f7b4e9 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -177,6 +177,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m) } +// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary +// function as TLSFingerprintProfile mutator. +type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f TLSFingerprintProfileFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.TLSFingerprintProfileMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TLSFingerprintProfileMutation", m) +} + // The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary // function as UsageCleanupTask mutator. 
type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index e7746402..13169ca7 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -23,6 +23,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -466,6 +467,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) } +// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary function as a Querier. +type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f TLSFingerprintProfileFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.TLSFingerprintProfileQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.TLSFingerprintProfileQuery", q) +} + +// The TraverseTLSFingerprintProfile type is an adapter to allow the use of ordinary function as Traverser. +type TraverseTLSFingerprintProfile func(context.Context, *ent.TLSFingerprintProfileQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseTLSFingerprintProfile) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseTLSFingerprintProfile) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.TLSFingerprintProfileQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. 
expect *ent.TLSFingerprintProfileQuery", q) +} + // The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier. type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error) @@ -686,6 +714,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil + case *ent.TLSFingerprintProfileQuery: + return &query[*ent.TLSFingerprintProfileQuery, predicate.TLSFingerprintProfile, tlsfingerprintprofile.OrderOption]{typ: ent.TypeTLSFingerprintProfile, tq: q}, nil case *ent.UsageCleanupTaskQuery: return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil case *ent.UsageLogQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ff1c1b88..6c56f2d0 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -409,6 +409,8 @@ var ( {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false}, + {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, + {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, } // GroupsTable holds the schema information for the "groups" table. @@ -673,6 +675,30 @@ var ( Columns: SettingsColumns, PrimaryKey: []*schema.Column{SettingsColumns[0]}, } + // TLSFingerprintProfilesColumns holds the columns for the "tls_fingerprint_profiles" table. 
+ TLSFingerprintProfilesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "enable_grease", Type: field.TypeBool, Default: false}, + {Name: "cipher_suites", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "curves", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "point_formats", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "signature_algorithms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "alpn_protocols", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "supported_versions", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "key_share_groups", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "psk_modes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "extensions", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + } + // TLSFingerprintProfilesTable holds the schema information for the "tls_fingerprint_profiles" table. + TLSFingerprintProfilesTable = &schema.Table{ + Name: "tls_fingerprint_profiles", + Columns: TLSFingerprintProfilesColumns, + PrimaryKey: []*schema.Column{TLSFingerprintProfilesColumns[0]}, + } // UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table. 
UsageCleanupTasksColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -716,6 +742,8 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100}, + {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -755,31 +783,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -788,38 +816,43 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, 
}, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_model", Unique: false, Columns: []*schema.Column{UsageLogsColumns[2]}, }, + { + Name: "usagelog_requested_model", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[3]}, + }, { Name: "usagelog_request_id", Unique: false, @@ -828,17 +861,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]}, }, }, } @@ -1104,6 +1137,7 @@ var ( RedeemCodesTable, SecuritySecretsTable, SettingsTable, + TLSFingerprintProfilesTable, UsageCleanupTasksTable, UsageLogsTable, UsersTable, @@ -1168,6 +1202,9 @@ func init() { SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } + TLSFingerprintProfilesTable.Annotation = &entsql.Annotation{ + Table: 
"tls_fingerprint_profiles", + } UsageCleanupTasksTable.Annotation = &entsql.Annotation{ Table: "usage_cleanup_tasks", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 652adcac..a862209d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -27,6 +27,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -60,6 +61,7 @@ const ( TypeRedeemCode = "RedeemCode" TypeSecuritySecret = "SecuritySecret" TypeSetting = "Setting" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" TypeUsageCleanupTask = "UsageCleanupTask" TypeUsageLog = "UsageLog" TypeUser = "User" @@ -8251,6 +8253,8 @@ type GroupMutation struct { sort_order *int addsort_order *int allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool default_mapped_model *string clearedFields map[string]struct{} api_keys map[int64]struct{} @@ -10032,6 +10036,78 @@ func (m *GroupMutation) ResetAllowMessagesDispatch() { m.allow_messages_dispatch = nil } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (m *GroupMutation) SetRequireOauthOnly(b bool) { + m.require_oauth_only = &b +} + +// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. +func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { + v := m.require_oauth_only + if v == nil { + return + } + return *v, true +} + +// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) + } + return oldValue.RequireOauthOnly, nil +} + +// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. +func (m *GroupMutation) ResetRequireOauthOnly() { + m.require_oauth_only = nil +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (m *GroupMutation) SetRequirePrivacySet(b bool) { + m.require_privacy_set = &b +} + +// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. +func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { + v := m.require_privacy_set + if v == nil { + return + } + return *v, true +} + +// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + } + return oldValue.RequirePrivacySet, nil +} + +// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. 
+func (m *GroupMutation) ResetRequirePrivacySet() { + m.require_privacy_set = nil +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (m *GroupMutation) SetDefaultMappedModel(s string) { m.default_mapped_model = &s @@ -10426,7 +10502,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 34) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10520,6 +10596,12 @@ func (m *GroupMutation) Fields() []string { if m.allow_messages_dispatch != nil { fields = append(fields, group.FieldAllowMessagesDispatch) } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } @@ -10593,6 +10675,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SortOrder() case group.FieldAllowMessagesDispatch: return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() } @@ -10666,6 +10752,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSortOrder(ctx) case group.FieldAllowMessagesDispatch: return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) } @@ -10894,6 +10984,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetAllowMessagesDispatch(v) return nil + case 
group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil case group.FieldDefaultMappedModel: v, ok := value.(string) if !ok { @@ -11331,6 +11435,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldAllowMessagesDispatch: m.ResetAllowMessagesDispatch() return nil + case group.FieldRequireOauthOnly: + m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil @@ -17148,6 +17258,1380 @@ func (m *SettingMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Setting edge %s", name) } +// TLSFingerprintProfileMutation represents an operation that mutates the TLSFingerprintProfile nodes in the graph. 
+type TLSFingerprintProfileMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + description *string + enable_grease *bool + cipher_suites *[]uint16 + appendcipher_suites []uint16 + curves *[]uint16 + appendcurves []uint16 + point_formats *[]uint16 + appendpoint_formats []uint16 + signature_algorithms *[]uint16 + appendsignature_algorithms []uint16 + alpn_protocols *[]string + appendalpn_protocols []string + supported_versions *[]uint16 + appendsupported_versions []uint16 + key_share_groups *[]uint16 + appendkey_share_groups []uint16 + psk_modes *[]uint16 + appendpsk_modes []uint16 + extensions *[]uint16 + appendextensions []uint16 + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*TLSFingerprintProfile, error) + predicates []predicate.TLSFingerprintProfile +} + +var _ ent.Mutation = (*TLSFingerprintProfileMutation)(nil) + +// tlsfingerprintprofileOption allows management of the mutation configuration using functional options. +type tlsfingerprintprofileOption func(*TLSFingerprintProfileMutation) + +// newTLSFingerprintProfileMutation creates new mutation for the TLSFingerprintProfile entity. +func newTLSFingerprintProfileMutation(c config, op Op, opts ...tlsfingerprintprofileOption) *TLSFingerprintProfileMutation { + m := &TLSFingerprintProfileMutation{ + config: c, + op: op, + typ: TypeTLSFingerprintProfile, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withTLSFingerprintProfileID sets the ID field of the mutation. 
+func withTLSFingerprintProfileID(id int64) tlsfingerprintprofileOption { + return func(m *TLSFingerprintProfileMutation) { + var ( + err error + once sync.Once + value *TLSFingerprintProfile + ) + m.oldValue = func(ctx context.Context) (*TLSFingerprintProfile, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().TLSFingerprintProfile.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withTLSFingerprintProfile sets the old TLSFingerprintProfile of the mutation. +func withTLSFingerprintProfile(node *TLSFingerprintProfile) tlsfingerprintprofileOption { + return func(m *TLSFingerprintProfileMutation) { + m.oldValue = func(context.Context) (*TLSFingerprintProfile, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m TLSFingerprintProfileMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m TLSFingerprintProfileMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *TLSFingerprintProfileMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. 
+// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *TLSFingerprintProfileMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().TLSFingerprintProfile.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *TLSFingerprintProfileMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *TLSFingerprintProfileMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. 
+func (m *TLSFingerprintProfileMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *TLSFingerprintProfileMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *TLSFingerprintProfileMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *TLSFingerprintProfileMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetName sets the "name" field. +func (m *TLSFingerprintProfileMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *TLSFingerprintProfileMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *TLSFingerprintProfileMutation) ResetName() { + m.name = nil +} + +// SetDescription sets the "description" field. +func (m *TLSFingerprintProfileMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *TLSFingerprintProfileMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. 
+func (m *TLSFingerprintProfileMutation) ClearDescription() { + m.description = nil + m.clearedFields[tlsfingerprintprofile.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *TLSFingerprintProfileMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldDescription) +} + +// SetEnableGrease sets the "enable_grease" field. +func (m *TLSFingerprintProfileMutation) SetEnableGrease(b bool) { + m.enable_grease = &b +} + +// EnableGrease returns the value of the "enable_grease" field in the mutation. +func (m *TLSFingerprintProfileMutation) EnableGrease() (r bool, exists bool) { + v := m.enable_grease + if v == nil { + return + } + return *v, true +} + +// OldEnableGrease returns the old "enable_grease" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldEnableGrease(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnableGrease is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnableGrease requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnableGrease: %w", err) + } + return oldValue.EnableGrease, nil +} + +// ResetEnableGrease resets all changes to the "enable_grease" field. 
+func (m *TLSFingerprintProfileMutation) ResetEnableGrease() { + m.enable_grease = nil +} + +// SetCipherSuites sets the "cipher_suites" field. +func (m *TLSFingerprintProfileMutation) SetCipherSuites(u []uint16) { + m.cipher_suites = &u + m.appendcipher_suites = nil +} + +// CipherSuites returns the value of the "cipher_suites" field in the mutation. +func (m *TLSFingerprintProfileMutation) CipherSuites() (r []uint16, exists bool) { + v := m.cipher_suites + if v == nil { + return + } + return *v, true +} + +// OldCipherSuites returns the old "cipher_suites" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldCipherSuites(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCipherSuites is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCipherSuites requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCipherSuites: %w", err) + } + return oldValue.CipherSuites, nil +} + +// AppendCipherSuites adds u to the "cipher_suites" field. +func (m *TLSFingerprintProfileMutation) AppendCipherSuites(u []uint16) { + m.appendcipher_suites = append(m.appendcipher_suites, u...) +} + +// AppendedCipherSuites returns the list of values that were appended to the "cipher_suites" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedCipherSuites() ([]uint16, bool) { + if len(m.appendcipher_suites) == 0 { + return nil, false + } + return m.appendcipher_suites, true +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. 
+func (m *TLSFingerprintProfileMutation) ClearCipherSuites() { + m.cipher_suites = nil + m.appendcipher_suites = nil + m.clearedFields[tlsfingerprintprofile.FieldCipherSuites] = struct{}{} +} + +// CipherSuitesCleared returns if the "cipher_suites" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) CipherSuitesCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldCipherSuites] + return ok +} + +// ResetCipherSuites resets all changes to the "cipher_suites" field. +func (m *TLSFingerprintProfileMutation) ResetCipherSuites() { + m.cipher_suites = nil + m.appendcipher_suites = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldCipherSuites) +} + +// SetCurves sets the "curves" field. +func (m *TLSFingerprintProfileMutation) SetCurves(u []uint16) { + m.curves = &u + m.appendcurves = nil +} + +// Curves returns the value of the "curves" field in the mutation. +func (m *TLSFingerprintProfileMutation) Curves() (r []uint16, exists bool) { + v := m.curves + if v == nil { + return + } + return *v, true +} + +// OldCurves returns the old "curves" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldCurves(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCurves is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCurves requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCurves: %w", err) + } + return oldValue.Curves, nil +} + +// AppendCurves adds u to the "curves" field. 
+func (m *TLSFingerprintProfileMutation) AppendCurves(u []uint16) { + m.appendcurves = append(m.appendcurves, u...) +} + +// AppendedCurves returns the list of values that were appended to the "curves" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedCurves() ([]uint16, bool) { + if len(m.appendcurves) == 0 { + return nil, false + } + return m.appendcurves, true +} + +// ClearCurves clears the value of the "curves" field. +func (m *TLSFingerprintProfileMutation) ClearCurves() { + m.curves = nil + m.appendcurves = nil + m.clearedFields[tlsfingerprintprofile.FieldCurves] = struct{}{} +} + +// CurvesCleared returns if the "curves" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) CurvesCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldCurves] + return ok +} + +// ResetCurves resets all changes to the "curves" field. +func (m *TLSFingerprintProfileMutation) ResetCurves() { + m.curves = nil + m.appendcurves = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldCurves) +} + +// SetPointFormats sets the "point_formats" field. +func (m *TLSFingerprintProfileMutation) SetPointFormats(u []uint16) { + m.point_formats = &u + m.appendpoint_formats = nil +} + +// PointFormats returns the value of the "point_formats" field in the mutation. +func (m *TLSFingerprintProfileMutation) PointFormats() (r []uint16, exists bool) { + v := m.point_formats + if v == nil { + return + } + return *v, true +} + +// OldPointFormats returns the old "point_formats" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *TLSFingerprintProfileMutation) OldPointFormats(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPointFormats is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPointFormats requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPointFormats: %w", err) + } + return oldValue.PointFormats, nil +} + +// AppendPointFormats adds u to the "point_formats" field. +func (m *TLSFingerprintProfileMutation) AppendPointFormats(u []uint16) { + m.appendpoint_formats = append(m.appendpoint_formats, u...) +} + +// AppendedPointFormats returns the list of values that were appended to the "point_formats" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedPointFormats() ([]uint16, bool) { + if len(m.appendpoint_formats) == 0 { + return nil, false + } + return m.appendpoint_formats, true +} + +// ClearPointFormats clears the value of the "point_formats" field. +func (m *TLSFingerprintProfileMutation) ClearPointFormats() { + m.point_formats = nil + m.appendpoint_formats = nil + m.clearedFields[tlsfingerprintprofile.FieldPointFormats] = struct{}{} +} + +// PointFormatsCleared returns if the "point_formats" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) PointFormatsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldPointFormats] + return ok +} + +// ResetPointFormats resets all changes to the "point_formats" field. +func (m *TLSFingerprintProfileMutation) ResetPointFormats() { + m.point_formats = nil + m.appendpoint_formats = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldPointFormats) +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. 
+func (m *TLSFingerprintProfileMutation) SetSignatureAlgorithms(u []uint16) { + m.signature_algorithms = &u + m.appendsignature_algorithms = nil +} + +// SignatureAlgorithms returns the value of the "signature_algorithms" field in the mutation. +func (m *TLSFingerprintProfileMutation) SignatureAlgorithms() (r []uint16, exists bool) { + v := m.signature_algorithms + if v == nil { + return + } + return *v, true +} + +// OldSignatureAlgorithms returns the old "signature_algorithms" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldSignatureAlgorithms(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignatureAlgorithms is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignatureAlgorithms requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignatureAlgorithms: %w", err) + } + return oldValue.SignatureAlgorithms, nil +} + +// AppendSignatureAlgorithms adds u to the "signature_algorithms" field. +func (m *TLSFingerprintProfileMutation) AppendSignatureAlgorithms(u []uint16) { + m.appendsignature_algorithms = append(m.appendsignature_algorithms, u...) +} + +// AppendedSignatureAlgorithms returns the list of values that were appended to the "signature_algorithms" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedSignatureAlgorithms() ([]uint16, bool) { + if len(m.appendsignature_algorithms) == 0 { + return nil, false + } + return m.appendsignature_algorithms, true +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. 
+func (m *TLSFingerprintProfileMutation) ClearSignatureAlgorithms() { + m.signature_algorithms = nil + m.appendsignature_algorithms = nil + m.clearedFields[tlsfingerprintprofile.FieldSignatureAlgorithms] = struct{}{} +} + +// SignatureAlgorithmsCleared returns if the "signature_algorithms" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) SignatureAlgorithmsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldSignatureAlgorithms] + return ok +} + +// ResetSignatureAlgorithms resets all changes to the "signature_algorithms" field. +func (m *TLSFingerprintProfileMutation) ResetSignatureAlgorithms() { + m.signature_algorithms = nil + m.appendsignature_algorithms = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldSignatureAlgorithms) +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (m *TLSFingerprintProfileMutation) SetAlpnProtocols(s []string) { + m.alpn_protocols = &s + m.appendalpn_protocols = nil +} + +// AlpnProtocols returns the value of the "alpn_protocols" field in the mutation. +func (m *TLSFingerprintProfileMutation) AlpnProtocols() (r []string, exists bool) { + v := m.alpn_protocols + if v == nil { + return + } + return *v, true +} + +// OldAlpnProtocols returns the old "alpn_protocols" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *TLSFingerprintProfileMutation) OldAlpnProtocols(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAlpnProtocols is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAlpnProtocols requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAlpnProtocols: %w", err) + } + return oldValue.AlpnProtocols, nil +} + +// AppendAlpnProtocols adds s to the "alpn_protocols" field. +func (m *TLSFingerprintProfileMutation) AppendAlpnProtocols(s []string) { + m.appendalpn_protocols = append(m.appendalpn_protocols, s...) +} + +// AppendedAlpnProtocols returns the list of values that were appended to the "alpn_protocols" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedAlpnProtocols() ([]string, bool) { + if len(m.appendalpn_protocols) == 0 { + return nil, false + } + return m.appendalpn_protocols, true +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (m *TLSFingerprintProfileMutation) ClearAlpnProtocols() { + m.alpn_protocols = nil + m.appendalpn_protocols = nil + m.clearedFields[tlsfingerprintprofile.FieldAlpnProtocols] = struct{}{} +} + +// AlpnProtocolsCleared returns if the "alpn_protocols" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) AlpnProtocolsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldAlpnProtocols] + return ok +} + +// ResetAlpnProtocols resets all changes to the "alpn_protocols" field. +func (m *TLSFingerprintProfileMutation) ResetAlpnProtocols() { + m.alpn_protocols = nil + m.appendalpn_protocols = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldAlpnProtocols) +} + +// SetSupportedVersions sets the "supported_versions" field. 
+func (m *TLSFingerprintProfileMutation) SetSupportedVersions(u []uint16) { + m.supported_versions = &u + m.appendsupported_versions = nil +} + +// SupportedVersions returns the value of the "supported_versions" field in the mutation. +func (m *TLSFingerprintProfileMutation) SupportedVersions() (r []uint16, exists bool) { + v := m.supported_versions + if v == nil { + return + } + return *v, true +} + +// OldSupportedVersions returns the old "supported_versions" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldSupportedVersions(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedVersions is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedVersions requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedVersions: %w", err) + } + return oldValue.SupportedVersions, nil +} + +// AppendSupportedVersions adds u to the "supported_versions" field. +func (m *TLSFingerprintProfileMutation) AppendSupportedVersions(u []uint16) { + m.appendsupported_versions = append(m.appendsupported_versions, u...) +} + +// AppendedSupportedVersions returns the list of values that were appended to the "supported_versions" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedSupportedVersions() ([]uint16, bool) { + if len(m.appendsupported_versions) == 0 { + return nil, false + } + return m.appendsupported_versions, true +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. 
+func (m *TLSFingerprintProfileMutation) ClearSupportedVersions() { + m.supported_versions = nil + m.appendsupported_versions = nil + m.clearedFields[tlsfingerprintprofile.FieldSupportedVersions] = struct{}{} +} + +// SupportedVersionsCleared returns if the "supported_versions" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) SupportedVersionsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldSupportedVersions] + return ok +} + +// ResetSupportedVersions resets all changes to the "supported_versions" field. +func (m *TLSFingerprintProfileMutation) ResetSupportedVersions() { + m.supported_versions = nil + m.appendsupported_versions = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldSupportedVersions) +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (m *TLSFingerprintProfileMutation) SetKeyShareGroups(u []uint16) { + m.key_share_groups = &u + m.appendkey_share_groups = nil +} + +// KeyShareGroups returns the value of the "key_share_groups" field in the mutation. +func (m *TLSFingerprintProfileMutation) KeyShareGroups() (r []uint16, exists bool) { + v := m.key_share_groups + if v == nil { + return + } + return *v, true +} + +// OldKeyShareGroups returns the old "key_share_groups" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *TLSFingerprintProfileMutation) OldKeyShareGroups(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKeyShareGroups is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKeyShareGroups requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKeyShareGroups: %w", err) + } + return oldValue.KeyShareGroups, nil +} + +// AppendKeyShareGroups adds u to the "key_share_groups" field. +func (m *TLSFingerprintProfileMutation) AppendKeyShareGroups(u []uint16) { + m.appendkey_share_groups = append(m.appendkey_share_groups, u...) +} + +// AppendedKeyShareGroups returns the list of values that were appended to the "key_share_groups" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedKeyShareGroups() ([]uint16, bool) { + if len(m.appendkey_share_groups) == 0 { + return nil, false + } + return m.appendkey_share_groups, true +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. +func (m *TLSFingerprintProfileMutation) ClearKeyShareGroups() { + m.key_share_groups = nil + m.appendkey_share_groups = nil + m.clearedFields[tlsfingerprintprofile.FieldKeyShareGroups] = struct{}{} +} + +// KeyShareGroupsCleared returns if the "key_share_groups" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) KeyShareGroupsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldKeyShareGroups] + return ok +} + +// ResetKeyShareGroups resets all changes to the "key_share_groups" field. +func (m *TLSFingerprintProfileMutation) ResetKeyShareGroups() { + m.key_share_groups = nil + m.appendkey_share_groups = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldKeyShareGroups) +} + +// SetPskModes sets the "psk_modes" field. 
+func (m *TLSFingerprintProfileMutation) SetPskModes(u []uint16) { + m.psk_modes = &u + m.appendpsk_modes = nil +} + +// PskModes returns the value of the "psk_modes" field in the mutation. +func (m *TLSFingerprintProfileMutation) PskModes() (r []uint16, exists bool) { + v := m.psk_modes + if v == nil { + return + } + return *v, true +} + +// OldPskModes returns the old "psk_modes" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldPskModes(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPskModes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPskModes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPskModes: %w", err) + } + return oldValue.PskModes, nil +} + +// AppendPskModes adds u to the "psk_modes" field. +func (m *TLSFingerprintProfileMutation) AppendPskModes(u []uint16) { + m.appendpsk_modes = append(m.appendpsk_modes, u...) +} + +// AppendedPskModes returns the list of values that were appended to the "psk_modes" field in this mutation. +func (m *TLSFingerprintProfileMutation) AppendedPskModes() ([]uint16, bool) { + if len(m.appendpsk_modes) == 0 { + return nil, false + } + return m.appendpsk_modes, true +} + +// ClearPskModes clears the value of the "psk_modes" field. +func (m *TLSFingerprintProfileMutation) ClearPskModes() { + m.psk_modes = nil + m.appendpsk_modes = nil + m.clearedFields[tlsfingerprintprofile.FieldPskModes] = struct{}{} +} + +// PskModesCleared returns if the "psk_modes" field was cleared in this mutation. 
+func (m *TLSFingerprintProfileMutation) PskModesCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldPskModes] + return ok +} + +// ResetPskModes resets all changes to the "psk_modes" field. +func (m *TLSFingerprintProfileMutation) ResetPskModes() { + m.psk_modes = nil + m.appendpsk_modes = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldPskModes) +} + +// SetExtensions sets the "extensions" field. +func (m *TLSFingerprintProfileMutation) SetExtensions(u []uint16) { + m.extensions = &u + m.appendextensions = nil +} + +// Extensions returns the value of the "extensions" field in the mutation. +func (m *TLSFingerprintProfileMutation) Extensions() (r []uint16, exists bool) { + v := m.extensions + if v == nil { + return + } + return *v, true +} + +// OldExtensions returns the old "extensions" field's value of the TLSFingerprintProfile entity. +// If the TLSFingerprintProfile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TLSFingerprintProfileMutation) OldExtensions(ctx context.Context) (v []uint16, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExtensions is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExtensions requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExtensions: %w", err) + } + return oldValue.Extensions, nil +} + +// AppendExtensions adds u to the "extensions" field. +func (m *TLSFingerprintProfileMutation) AppendExtensions(u []uint16) { + m.appendextensions = append(m.appendextensions, u...) +} + +// AppendedExtensions returns the list of values that were appended to the "extensions" field in this mutation. 
+func (m *TLSFingerprintProfileMutation) AppendedExtensions() ([]uint16, bool) { + if len(m.appendextensions) == 0 { + return nil, false + } + return m.appendextensions, true +} + +// ClearExtensions clears the value of the "extensions" field. +func (m *TLSFingerprintProfileMutation) ClearExtensions() { + m.extensions = nil + m.appendextensions = nil + m.clearedFields[tlsfingerprintprofile.FieldExtensions] = struct{}{} +} + +// ExtensionsCleared returns if the "extensions" field was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) ExtensionsCleared() bool { + _, ok := m.clearedFields[tlsfingerprintprofile.FieldExtensions] + return ok +} + +// ResetExtensions resets all changes to the "extensions" field. +func (m *TLSFingerprintProfileMutation) ResetExtensions() { + m.extensions = nil + m.appendextensions = nil + delete(m.clearedFields, tlsfingerprintprofile.FieldExtensions) +} + +// Where appends a list predicates to the TLSFingerprintProfileMutation builder. +func (m *TLSFingerprintProfileMutation) Where(ps ...predicate.TLSFingerprintProfile) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the TLSFingerprintProfileMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *TLSFingerprintProfileMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.TLSFingerprintProfile, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *TLSFingerprintProfileMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *TLSFingerprintProfileMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (TLSFingerprintProfile). +func (m *TLSFingerprintProfileMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. 
Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *TLSFingerprintProfileMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, tlsfingerprintprofile.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, tlsfingerprintprofile.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, tlsfingerprintprofile.FieldName) + } + if m.description != nil { + fields = append(fields, tlsfingerprintprofile.FieldDescription) + } + if m.enable_grease != nil { + fields = append(fields, tlsfingerprintprofile.FieldEnableGrease) + } + if m.cipher_suites != nil { + fields = append(fields, tlsfingerprintprofile.FieldCipherSuites) + } + if m.curves != nil { + fields = append(fields, tlsfingerprintprofile.FieldCurves) + } + if m.point_formats != nil { + fields = append(fields, tlsfingerprintprofile.FieldPointFormats) + } + if m.signature_algorithms != nil { + fields = append(fields, tlsfingerprintprofile.FieldSignatureAlgorithms) + } + if m.alpn_protocols != nil { + fields = append(fields, tlsfingerprintprofile.FieldAlpnProtocols) + } + if m.supported_versions != nil { + fields = append(fields, tlsfingerprintprofile.FieldSupportedVersions) + } + if m.key_share_groups != nil { + fields = append(fields, tlsfingerprintprofile.FieldKeyShareGroups) + } + if m.psk_modes != nil { + fields = append(fields, tlsfingerprintprofile.FieldPskModes) + } + if m.extensions != nil { + fields = append(fields, tlsfingerprintprofile.FieldExtensions) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *TLSFingerprintProfileMutation) Field(name string) (ent.Value, bool) { + switch name { + case tlsfingerprintprofile.FieldCreatedAt: + return m.CreatedAt() + case tlsfingerprintprofile.FieldUpdatedAt: + return m.UpdatedAt() + case tlsfingerprintprofile.FieldName: + return m.Name() + case tlsfingerprintprofile.FieldDescription: + return m.Description() + case tlsfingerprintprofile.FieldEnableGrease: + return m.EnableGrease() + case tlsfingerprintprofile.FieldCipherSuites: + return m.CipherSuites() + case tlsfingerprintprofile.FieldCurves: + return m.Curves() + case tlsfingerprintprofile.FieldPointFormats: + return m.PointFormats() + case tlsfingerprintprofile.FieldSignatureAlgorithms: + return m.SignatureAlgorithms() + case tlsfingerprintprofile.FieldAlpnProtocols: + return m.AlpnProtocols() + case tlsfingerprintprofile.FieldSupportedVersions: + return m.SupportedVersions() + case tlsfingerprintprofile.FieldKeyShareGroups: + return m.KeyShareGroups() + case tlsfingerprintprofile.FieldPskModes: + return m.PskModes() + case tlsfingerprintprofile.FieldExtensions: + return m.Extensions() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *TLSFingerprintProfileMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case tlsfingerprintprofile.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case tlsfingerprintprofile.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case tlsfingerprintprofile.FieldName: + return m.OldName(ctx) + case tlsfingerprintprofile.FieldDescription: + return m.OldDescription(ctx) + case tlsfingerprintprofile.FieldEnableGrease: + return m.OldEnableGrease(ctx) + case tlsfingerprintprofile.FieldCipherSuites: + return m.OldCipherSuites(ctx) + case tlsfingerprintprofile.FieldCurves: + return m.OldCurves(ctx) + case tlsfingerprintprofile.FieldPointFormats: + return m.OldPointFormats(ctx) + case tlsfingerprintprofile.FieldSignatureAlgorithms: + return m.OldSignatureAlgorithms(ctx) + case tlsfingerprintprofile.FieldAlpnProtocols: + return m.OldAlpnProtocols(ctx) + case tlsfingerprintprofile.FieldSupportedVersions: + return m.OldSupportedVersions(ctx) + case tlsfingerprintprofile.FieldKeyShareGroups: + return m.OldKeyShareGroups(ctx) + case tlsfingerprintprofile.FieldPskModes: + return m.OldPskModes(ctx) + case tlsfingerprintprofile.FieldExtensions: + return m.OldExtensions(ctx) + } + return nil, fmt.Errorf("unknown TLSFingerprintProfile field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *TLSFingerprintProfileMutation) SetField(name string, value ent.Value) error { + switch name { + case tlsfingerprintprofile.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case tlsfingerprintprofile.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case tlsfingerprintprofile.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case tlsfingerprintprofile.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + case tlsfingerprintprofile.FieldEnableGrease: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnableGrease(v) + return nil + case tlsfingerprintprofile.FieldCipherSuites: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCipherSuites(v) + return nil + case tlsfingerprintprofile.FieldCurves: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCurves(v) + return nil + case tlsfingerprintprofile.FieldPointFormats: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPointFormats(v) + return nil + case tlsfingerprintprofile.FieldSignatureAlgorithms: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignatureAlgorithms(v) + return nil + case tlsfingerprintprofile.FieldAlpnProtocols: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + 
m.SetAlpnProtocols(v) + return nil + case tlsfingerprintprofile.FieldSupportedVersions: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedVersions(v) + return nil + case tlsfingerprintprofile.FieldKeyShareGroups: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeyShareGroups(v) + return nil + case tlsfingerprintprofile.FieldPskModes: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPskModes(v) + return nil + case tlsfingerprintprofile.FieldExtensions: + v, ok := value.([]uint16) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExtensions(v) + return nil + } + return fmt.Errorf("unknown TLSFingerprintProfile field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *TLSFingerprintProfileMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *TLSFingerprintProfileMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *TLSFingerprintProfileMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown TLSFingerprintProfile numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *TLSFingerprintProfileMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(tlsfingerprintprofile.FieldDescription) { + fields = append(fields, tlsfingerprintprofile.FieldDescription) + } + if m.FieldCleared(tlsfingerprintprofile.FieldCipherSuites) { + fields = append(fields, tlsfingerprintprofile.FieldCipherSuites) + } + if m.FieldCleared(tlsfingerprintprofile.FieldCurves) { + fields = append(fields, tlsfingerprintprofile.FieldCurves) + } + if m.FieldCleared(tlsfingerprintprofile.FieldPointFormats) { + fields = append(fields, tlsfingerprintprofile.FieldPointFormats) + } + if m.FieldCleared(tlsfingerprintprofile.FieldSignatureAlgorithms) { + fields = append(fields, tlsfingerprintprofile.FieldSignatureAlgorithms) + } + if m.FieldCleared(tlsfingerprintprofile.FieldAlpnProtocols) { + fields = append(fields, tlsfingerprintprofile.FieldAlpnProtocols) + } + if m.FieldCleared(tlsfingerprintprofile.FieldSupportedVersions) { + fields = append(fields, tlsfingerprintprofile.FieldSupportedVersions) + } + if m.FieldCleared(tlsfingerprintprofile.FieldKeyShareGroups) { + fields = append(fields, tlsfingerprintprofile.FieldKeyShareGroups) + } + if m.FieldCleared(tlsfingerprintprofile.FieldPskModes) { + fields = append(fields, tlsfingerprintprofile.FieldPskModes) + } + if m.FieldCleared(tlsfingerprintprofile.FieldExtensions) { + fields = append(fields, tlsfingerprintprofile.FieldExtensions) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *TLSFingerprintProfileMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. 
+func (m *TLSFingerprintProfileMutation) ClearField(name string) error { + switch name { + case tlsfingerprintprofile.FieldDescription: + m.ClearDescription() + return nil + case tlsfingerprintprofile.FieldCipherSuites: + m.ClearCipherSuites() + return nil + case tlsfingerprintprofile.FieldCurves: + m.ClearCurves() + return nil + case tlsfingerprintprofile.FieldPointFormats: + m.ClearPointFormats() + return nil + case tlsfingerprintprofile.FieldSignatureAlgorithms: + m.ClearSignatureAlgorithms() + return nil + case tlsfingerprintprofile.FieldAlpnProtocols: + m.ClearAlpnProtocols() + return nil + case tlsfingerprintprofile.FieldSupportedVersions: + m.ClearSupportedVersions() + return nil + case tlsfingerprintprofile.FieldKeyShareGroups: + m.ClearKeyShareGroups() + return nil + case tlsfingerprintprofile.FieldPskModes: + m.ClearPskModes() + return nil + case tlsfingerprintprofile.FieldExtensions: + m.ClearExtensions() + return nil + } + return fmt.Errorf("unknown TLSFingerprintProfile nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *TLSFingerprintProfileMutation) ResetField(name string) error { + switch name { + case tlsfingerprintprofile.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case tlsfingerprintprofile.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case tlsfingerprintprofile.FieldName: + m.ResetName() + return nil + case tlsfingerprintprofile.FieldDescription: + m.ResetDescription() + return nil + case tlsfingerprintprofile.FieldEnableGrease: + m.ResetEnableGrease() + return nil + case tlsfingerprintprofile.FieldCipherSuites: + m.ResetCipherSuites() + return nil + case tlsfingerprintprofile.FieldCurves: + m.ResetCurves() + return nil + case tlsfingerprintprofile.FieldPointFormats: + m.ResetPointFormats() + return nil + case tlsfingerprintprofile.FieldSignatureAlgorithms: + m.ResetSignatureAlgorithms() + return nil + case tlsfingerprintprofile.FieldAlpnProtocols: + m.ResetAlpnProtocols() + return nil + case tlsfingerprintprofile.FieldSupportedVersions: + m.ResetSupportedVersions() + return nil + case tlsfingerprintprofile.FieldKeyShareGroups: + m.ResetKeyShareGroups() + return nil + case tlsfingerprintprofile.FieldPskModes: + m.ResetPskModes() + return nil + case tlsfingerprintprofile.FieldExtensions: + m.ResetExtensions() + return nil + } + return fmt.Errorf("unknown TLSFingerprintProfile field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *TLSFingerprintProfileMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *TLSFingerprintProfileMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. 
+func (m *TLSFingerprintProfileMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *TLSFingerprintProfileMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *TLSFingerprintProfileMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *TLSFingerprintProfileMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *TLSFingerprintProfileMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown TLSFingerprintProfile unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *TLSFingerprintProfileMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown TLSFingerprintProfile edge %s", name) +} + // UsageCleanupTaskMutation represents an operation that mutates the UsageCleanupTask nodes in the graph. type UsageCleanupTaskMutation struct { config @@ -18239,6 +19723,8 @@ type UsageLogMutation struct { id *int64 request_id *string model *string + requested_model *string + upstream_model *string input_tokens *int addinput_tokens *int output_tokens *int @@ -18576,6 +20062,104 @@ func (m *UsageLogMutation) ResetModel() { m.model = nil } +// SetRequestedModel sets the "requested_model" field. +func (m *UsageLogMutation) SetRequestedModel(s string) { + m.requested_model = &s +} + +// RequestedModel returns the value of the "requested_model" field in the mutation. 
+func (m *UsageLogMutation) RequestedModel() (r string, exists bool) { + v := m.requested_model + if v == nil { + return + } + return *v, true +} + +// OldRequestedModel returns the old "requested_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRequestedModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestedModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestedModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestedModel: %w", err) + } + return oldValue.RequestedModel, nil +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (m *UsageLogMutation) ClearRequestedModel() { + m.requested_model = nil + m.clearedFields[usagelog.FieldRequestedModel] = struct{}{} +} + +// RequestedModelCleared returns if the "requested_model" field was cleared in this mutation. +func (m *UsageLogMutation) RequestedModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldRequestedModel] + return ok +} + +// ResetRequestedModel resets all changes to the "requested_model" field. +func (m *UsageLogMutation) ResetRequestedModel() { + m.requested_model = nil + delete(m.clearedFields, usagelog.FieldRequestedModel) +} + +// SetUpstreamModel sets the "upstream_model" field. +func (m *UsageLogMutation) SetUpstreamModel(s string) { + m.upstream_model = &s +} + +// UpstreamModel returns the value of the "upstream_model" field in the mutation. 
+func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) { + v := m.upstream_model + if v == nil { + return + } + return *v, true +} + +// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err) + } + return oldValue.UpstreamModel, nil +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (m *UsageLogMutation) ClearUpstreamModel() { + m.upstream_model = nil + m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{} +} + +// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation. +func (m *UsageLogMutation) UpstreamModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldUpstreamModel] + return ok +} + +// ResetUpstreamModel resets all changes to the "upstream_model" field. +func (m *UsageLogMutation) ResetUpstreamModel() { + m.upstream_model = nil + delete(m.clearedFields, usagelog.FieldUpstreamModel) +} + // SetGroupID sets the "group_id" field. func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -20197,7 +21781,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 34) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -20213,6 +21797,12 @@ func (m *UsageLogMutation) Fields() []string { if m.model != nil { fields = append(fields, usagelog.FieldModel) } + if m.requested_model != nil { + fields = append(fields, usagelog.FieldRequestedModel) + } + if m.upstream_model != nil { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -20312,6 +21902,10 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case usagelog.FieldModel: return m.Model() + case usagelog.FieldRequestedModel: + return m.RequestedModel() + case usagelog.FieldUpstreamModel: + return m.UpstreamModel() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -20385,6 +21979,10 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestID(ctx) case usagelog.FieldModel: return m.OldModel(ctx) + case usagelog.FieldRequestedModel: + return m.OldRequestedModel(ctx) + case usagelog.FieldUpstreamModel: + return m.OldUpstreamModel(ctx) case usagelog.FieldGroupID: return m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -20483,6 +22081,20 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetModel(v) return nil + case usagelog.FieldRequestedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestedModel(v) + return nil + case usagelog.FieldUpstreamModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamModel(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -20921,6 +22533,12 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { 
// mutation. func (m *UsageLogMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usagelog.FieldRequestedModel) { + fields = append(fields, usagelog.FieldRequestedModel) + } + if m.FieldCleared(usagelog.FieldUpstreamModel) { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -20962,6 +22580,12 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UsageLogMutation) ClearField(name string) error { switch name { + case usagelog.FieldRequestedModel: + m.ClearRequestedModel() + return nil + case usagelog.FieldUpstreamModel: + m.ClearUpstreamModel() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -21012,6 +22636,12 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldModel: m.ResetModel() return nil + case usagelog.FieldRequestedModel: + m.ResetRequestedModel() + return nil + case usagelog.FieldUpstreamModel: + m.ResetUpstreamModel() + return nil case usagelog.FieldGroupID: m.ResetGroupID() return nil diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 89d933fc..a652ab3f 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -48,6 +48,9 @@ type SecuritySecret func(*sql.Selector) // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) +// TLSFingerprintProfile is the predicate function for tlsfingerprintprofile builders. +type TLSFingerprintProfile func(*sql.Selector) + // UsageCleanupTask is the predicate function for usagecleanuptask builders. 
type UsageCleanupTask func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index b8facf36..fd6be291 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -457,8 +458,16 @@ func init() { groupDescAllowMessagesDispatch := groupFields[27].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) + // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. + groupDescRequireOauthOnly := groupFields[28].Descriptor() + // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. + group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) + // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. + groupDescRequirePrivacySet := groupFields[29].Descriptor() + // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. + group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[28].Descriptor() + groupDescDefaultMappedModel := groupFields[30].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. 
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -746,6 +755,43 @@ func init() { setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + tlsfingerprintprofileMixin := schema.TLSFingerprintProfile{}.Mixin() + tlsfingerprintprofileMixinFields0 := tlsfingerprintprofileMixin[0].Fields() + _ = tlsfingerprintprofileMixinFields0 + tlsfingerprintprofileFields := schema.TLSFingerprintProfile{}.Fields() + _ = tlsfingerprintprofileFields + // tlsfingerprintprofileDescCreatedAt is the schema descriptor for created_at field. + tlsfingerprintprofileDescCreatedAt := tlsfingerprintprofileMixinFields0[0].Descriptor() + // tlsfingerprintprofile.DefaultCreatedAt holds the default value on creation for the created_at field. + tlsfingerprintprofile.DefaultCreatedAt = tlsfingerprintprofileDescCreatedAt.Default.(func() time.Time) + // tlsfingerprintprofileDescUpdatedAt is the schema descriptor for updated_at field. + tlsfingerprintprofileDescUpdatedAt := tlsfingerprintprofileMixinFields0[1].Descriptor() + // tlsfingerprintprofile.DefaultUpdatedAt holds the default value on creation for the updated_at field. + tlsfingerprintprofile.DefaultUpdatedAt = tlsfingerprintprofileDescUpdatedAt.Default.(func() time.Time) + // tlsfingerprintprofile.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + tlsfingerprintprofile.UpdateDefaultUpdatedAt = tlsfingerprintprofileDescUpdatedAt.UpdateDefault.(func() time.Time) + // tlsfingerprintprofileDescName is the schema descriptor for name field. 
+ tlsfingerprintprofileDescName := tlsfingerprintprofileFields[0].Descriptor() + // tlsfingerprintprofile.NameValidator is a validator for the "name" field. It is called by the builders before save. + tlsfingerprintprofile.NameValidator = func() func(string) error { + validators := tlsfingerprintprofileDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // tlsfingerprintprofileDescEnableGrease is the schema descriptor for enable_grease field. + tlsfingerprintprofileDescEnableGrease := tlsfingerprintprofileFields[2].Descriptor() + // tlsfingerprintprofile.DefaultEnableGrease holds the default value on creation for the enable_grease field. + tlsfingerprintprofile.DefaultEnableGrease = tlsfingerprintprofileDescEnableGrease.Default.(bool) usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin() usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields() _ = usagecleanuptaskMixinFields0 @@ -821,92 +867,100 @@ func init() { return nil } }() + // usagelogDescRequestedModel is the schema descriptor for requested_model field. + usagelogDescRequestedModel := usagelogFields[5].Descriptor() + // usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save. + usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error) + // usagelogDescUpstreamModel is the schema descriptor for upstream_model field. + usagelogDescUpstreamModel := usagelogFields[6].Descriptor() + // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. 
- usagelogDescInputTokens := usagelogFields[7].Descriptor() + usagelogDescInputTokens := usagelogFields[9].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[8].Descriptor() + usagelogDescOutputTokens := usagelogFields[10].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[12].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. 
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[13].Descriptor() + usagelogDescInputCost := usagelogFields[15].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[14].Descriptor() + usagelogDescOutputCost := usagelogFields[16].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[17].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + usagelogDescCacheReadCost := usagelogFields[18].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. 
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[17].Descriptor() + usagelogDescTotalCost := usagelogFields[19].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[18].Descriptor() + usagelogDescActualCost := usagelogFields[20].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + usagelogDescRateMultiplier := usagelogFields[21].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[21].Descriptor() + usagelogDescBillingType := usagelogFields[23].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[22].Descriptor() + usagelogDescStream := usagelogFields[24].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. 
- usagelogDescUserAgent := usagelogFields[25].Descriptor() + usagelogDescUserAgent := usagelogFields[27].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[26].Descriptor() + usagelogDescIPAddress := usagelogFields[28].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[27].Descriptor() + usagelogDescImageCount := usagelogFields[29].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[28].Descriptor() + usagelogDescImageSize := usagelogFields[30].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[29].Descriptor() + usagelogDescMediaType := usagelogFields[31].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. 
- usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[31].Descriptor() + usagelogDescCreatedAt := usagelogFields[33].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 0f5a7b14..fd83bf26 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -153,6 +153,12 @@ func (Group) Fields() []ent.Field { field.Bool("allow_messages_dispatch"). Default(false). Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"), + field.Bool("require_oauth_only"). + Default(false). + Comment("仅允许非 apikey 类型账号关联到此分组"), + field.Bool("require_privacy_set"). + Default(false). + Comment("调度时仅允许 privacy 已成功设置的账号"), field.String("default_mapped_model"). MaxLen(100). Default(""). 
diff --git a/backend/ent/schema/tls_fingerprint_profile.go b/backend/ent/schema/tls_fingerprint_profile.go new file mode 100644 index 00000000..86856d05 --- /dev/null +++ b/backend/ent/schema/tls_fingerprint_profile.go @@ -0,0 +1,100 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" +) + +// TLSFingerprintProfile 定义 TLS 指纹配置模板的 schema。 +// +// TLS 指纹模板用于模拟特定客户端(如 Claude Code / Node.js)的 TLS 握手特征。 +// 每个模板包含完整的 ClientHello 参数:加密套件、曲线、扩展等。 +// 通过 Account.Extra.tls_fingerprint_profile_id 绑定到具体账号。 +type TLSFingerprintProfile struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (TLSFingerprintProfile) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "tls_fingerprint_profiles"}, + } +} + +// Mixin 返回该 schema 使用的混入组件。 +func (TLSFingerprintProfile) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义 TLS 指纹模板实体的所有字段。 +func (TLSFingerprintProfile) Fields() []ent.Field { + return []ent.Field{ + // name: 模板名称,唯一标识 + field.String("name"). + MaxLen(100). + NotEmpty(). + Unique(), + + // description: 模板描述 + field.Text("description"). + Optional(). + Nillable(), + + // enable_grease: 是否启用 GREASE 扩展(Chrome 使用,Node.js 不使用) + field.Bool("enable_grease"). + Default(false), + + // cipher_suites: TLS 加密套件列表(顺序敏感,影响 JA3) + field.JSON("cipher_suites", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // curves: 椭圆曲线/支持的组列表 + field.JSON("curves", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // point_formats: EC 点格式列表 + field.JSON("point_formats", []uint16{}). + Optional(). 
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // signature_algorithms: 签名算法列表 + field.JSON("signature_algorithms", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // alpn_protocols: ALPN 协议列表(如 ["http/1.1"]) + field.JSON("alpn_protocols", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // supported_versions: 支持的 TLS 版本列表(如 [0x0304, 0x0303]) + field.JSON("supported_versions", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // key_share_groups: Key Share 中发送的曲线组(如 [29] 即 X25519) + field.JSON("key_share_groups", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // psk_modes: PSK 密钥交换模式(如 [1] 即 psk_dhe_ke) + field.JSON("psk_modes", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // extensions: TLS 扩展类型 ID 列表,按发送顺序排列 + field.JSON("extensions", []uint16{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index dcca1a0a..32c39e25 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -41,6 +41,18 @@ func (UsageLog) Fields() []ent.Field { field.String("model"). MaxLen(100). NotEmpty(), + // RequestedModel stores the client-requested model name for stable display and analytics. + // NULL means historical rows written before requested_model dual-write was introduced. + field.String("requested_model"). + MaxLen(100). + Optional(). + Nillable(), + // UpstreamModel stores the actual upstream model name when model mapping + // is applied. NULL means no mapping — the requested model was used as-is. + field.String("upstream_model"). + MaxLen(100). + Optional(). + Nillable(), field.Int64("group_id"). Optional(). 
Nillable(), @@ -175,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index { index.Fields("subscription_id"), index.Fields("created_at"), index.Fields("model"), + index.Fields("requested_model"), index.Fields("request_id"), // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), diff --git a/backend/ent/tlsfingerprintprofile.go b/backend/ent/tlsfingerprintprofile.go new file mode 100644 index 00000000..c9455609 --- /dev/null +++ b/backend/ent/tlsfingerprintprofile.go @@ -0,0 +1,275 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" +) + +// TLSFingerprintProfile is the model entity for the TLSFingerprintProfile schema. +type TLSFingerprintProfile struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Description holds the value of the "description" field. + Description *string `json:"description,omitempty"` + // EnableGrease holds the value of the "enable_grease" field. + EnableGrease bool `json:"enable_grease,omitempty"` + // CipherSuites holds the value of the "cipher_suites" field. + CipherSuites []uint16 `json:"cipher_suites,omitempty"` + // Curves holds the value of the "curves" field. + Curves []uint16 `json:"curves,omitempty"` + // PointFormats holds the value of the "point_formats" field. + PointFormats []uint16 `json:"point_formats,omitempty"` + // SignatureAlgorithms holds the value of the "signature_algorithms" field. + SignatureAlgorithms []uint16 `json:"signature_algorithms,omitempty"` + // AlpnProtocols holds the value of the "alpn_protocols" field. 
+ AlpnProtocols []string `json:"alpn_protocols,omitempty"` + // SupportedVersions holds the value of the "supported_versions" field. + SupportedVersions []uint16 `json:"supported_versions,omitempty"` + // KeyShareGroups holds the value of the "key_share_groups" field. + KeyShareGroups []uint16 `json:"key_share_groups,omitempty"` + // PskModes holds the value of the "psk_modes" field. + PskModes []uint16 `json:"psk_modes,omitempty"` + // Extensions holds the value of the "extensions" field. + Extensions []uint16 `json:"extensions,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*TLSFingerprintProfile) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case tlsfingerprintprofile.FieldCipherSuites, tlsfingerprintprofile.FieldCurves, tlsfingerprintprofile.FieldPointFormats, tlsfingerprintprofile.FieldSignatureAlgorithms, tlsfingerprintprofile.FieldAlpnProtocols, tlsfingerprintprofile.FieldSupportedVersions, tlsfingerprintprofile.FieldKeyShareGroups, tlsfingerprintprofile.FieldPskModes, tlsfingerprintprofile.FieldExtensions: + values[i] = new([]byte) + case tlsfingerprintprofile.FieldEnableGrease: + values[i] = new(sql.NullBool) + case tlsfingerprintprofile.FieldID: + values[i] = new(sql.NullInt64) + case tlsfingerprintprofile.FieldName, tlsfingerprintprofile.FieldDescription: + values[i] = new(sql.NullString) + case tlsfingerprintprofile.FieldCreatedAt, tlsfingerprintprofile.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the TLSFingerprintProfile fields. 
+func (_m *TLSFingerprintProfile) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case tlsfingerprintprofile.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case tlsfingerprintprofile.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case tlsfingerprintprofile.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case tlsfingerprintprofile.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case tlsfingerprintprofile.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + _m.Description = new(string) + *_m.Description = value.String + } + case tlsfingerprintprofile.FieldEnableGrease: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field enable_grease", values[i]) + } else if value.Valid { + _m.EnableGrease = value.Bool + } + case tlsfingerprintprofile.FieldCipherSuites: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field cipher_suites", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.CipherSuites); err != nil { + return fmt.Errorf("unmarshal field cipher_suites: %w", err) + } + } + case 
tlsfingerprintprofile.FieldCurves: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field curves", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Curves); err != nil { + return fmt.Errorf("unmarshal field curves: %w", err) + } + } + case tlsfingerprintprofile.FieldPointFormats: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field point_formats", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.PointFormats); err != nil { + return fmt.Errorf("unmarshal field point_formats: %w", err) + } + } + case tlsfingerprintprofile.FieldSignatureAlgorithms: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field signature_algorithms", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.SignatureAlgorithms); err != nil { + return fmt.Errorf("unmarshal field signature_algorithms: %w", err) + } + } + case tlsfingerprintprofile.FieldAlpnProtocols: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field alpn_protocols", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.AlpnProtocols); err != nil { + return fmt.Errorf("unmarshal field alpn_protocols: %w", err) + } + } + case tlsfingerprintprofile.FieldSupportedVersions: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field supported_versions", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.SupportedVersions); err != nil { + return fmt.Errorf("unmarshal field supported_versions: %w", err) + } + } + case tlsfingerprintprofile.FieldKeyShareGroups: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field key_share_groups", values[i]) + } else if value != nil && len(*value) 
> 0 { + if err := json.Unmarshal(*value, &_m.KeyShareGroups); err != nil { + return fmt.Errorf("unmarshal field key_share_groups: %w", err) + } + } + case tlsfingerprintprofile.FieldPskModes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field psk_modes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.PskModes); err != nil { + return fmt.Errorf("unmarshal field psk_modes: %w", err) + } + } + case tlsfingerprintprofile.FieldExtensions: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field extensions", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Extensions); err != nil { + return fmt.Errorf("unmarshal field extensions: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the TLSFingerprintProfile. +// This includes values selected through modifiers, order, etc. +func (_m *TLSFingerprintProfile) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this TLSFingerprintProfile. +// Note that you need to call TLSFingerprintProfile.Unwrap() before calling this method if this TLSFingerprintProfile +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *TLSFingerprintProfile) Update() *TLSFingerprintProfileUpdateOne { + return NewTLSFingerprintProfileClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the TLSFingerprintProfile entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. 
+func (_m *TLSFingerprintProfile) Unwrap() *TLSFingerprintProfile { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: TLSFingerprintProfile is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *TLSFingerprintProfile) String() string { + var builder strings.Builder + builder.WriteString("TLSFingerprintProfile(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + if v := _m.Description; v != nil { + builder.WriteString("description=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("enable_grease=") + builder.WriteString(fmt.Sprintf("%v", _m.EnableGrease)) + builder.WriteString(", ") + builder.WriteString("cipher_suites=") + builder.WriteString(fmt.Sprintf("%v", _m.CipherSuites)) + builder.WriteString(", ") + builder.WriteString("curves=") + builder.WriteString(fmt.Sprintf("%v", _m.Curves)) + builder.WriteString(", ") + builder.WriteString("point_formats=") + builder.WriteString(fmt.Sprintf("%v", _m.PointFormats)) + builder.WriteString(", ") + builder.WriteString("signature_algorithms=") + builder.WriteString(fmt.Sprintf("%v", _m.SignatureAlgorithms)) + builder.WriteString(", ") + builder.WriteString("alpn_protocols=") + builder.WriteString(fmt.Sprintf("%v", _m.AlpnProtocols)) + builder.WriteString(", ") + builder.WriteString("supported_versions=") + builder.WriteString(fmt.Sprintf("%v", _m.SupportedVersions)) + builder.WriteString(", ") + builder.WriteString("key_share_groups=") + builder.WriteString(fmt.Sprintf("%v", _m.KeyShareGroups)) + builder.WriteString(", ") + builder.WriteString("psk_modes=") + 
builder.WriteString(fmt.Sprintf("%v", _m.PskModes)) + builder.WriteString(", ") + builder.WriteString("extensions=") + builder.WriteString(fmt.Sprintf("%v", _m.Extensions)) + builder.WriteByte(')') + return builder.String() +} + +// TLSFingerprintProfiles is a parsable slice of TLSFingerprintProfile. +type TLSFingerprintProfiles []*TLSFingerprintProfile diff --git a/backend/ent/tlsfingerprintprofile/tlsfingerprintprofile.go b/backend/ent/tlsfingerprintprofile/tlsfingerprintprofile.go new file mode 100644 index 00000000..49426d36 --- /dev/null +++ b/backend/ent/tlsfingerprintprofile/tlsfingerprintprofile.go @@ -0,0 +1,121 @@ +// Code generated by ent, DO NOT EDIT. + +package tlsfingerprintprofile + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the tlsfingerprintprofile type in the database. + Label = "tls_fingerprint_profile" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // FieldEnableGrease holds the string denoting the enable_grease field in the database. + FieldEnableGrease = "enable_grease" + // FieldCipherSuites holds the string denoting the cipher_suites field in the database. + FieldCipherSuites = "cipher_suites" + // FieldCurves holds the string denoting the curves field in the database. + FieldCurves = "curves" + // FieldPointFormats holds the string denoting the point_formats field in the database. 
+ FieldPointFormats = "point_formats" + // FieldSignatureAlgorithms holds the string denoting the signature_algorithms field in the database. + FieldSignatureAlgorithms = "signature_algorithms" + // FieldAlpnProtocols holds the string denoting the alpn_protocols field in the database. + FieldAlpnProtocols = "alpn_protocols" + // FieldSupportedVersions holds the string denoting the supported_versions field in the database. + FieldSupportedVersions = "supported_versions" + // FieldKeyShareGroups holds the string denoting the key_share_groups field in the database. + FieldKeyShareGroups = "key_share_groups" + // FieldPskModes holds the string denoting the psk_modes field in the database. + FieldPskModes = "psk_modes" + // FieldExtensions holds the string denoting the extensions field in the database. + FieldExtensions = "extensions" + // Table holds the table name of the tlsfingerprintprofile in the database. + Table = "tls_fingerprint_profiles" +) + +// Columns holds all SQL columns for tlsfingerprintprofile fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldDescription, + FieldEnableGrease, + FieldCipherSuites, + FieldCurves, + FieldPointFormats, + FieldSignatureAlgorithms, + FieldAlpnProtocols, + FieldSupportedVersions, + FieldKeyShareGroups, + FieldPskModes, + FieldExtensions, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. 
+ UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // DefaultEnableGrease holds the default value on creation for the "enable_grease" field. + DefaultEnableGrease bool +) + +// OrderOption defines the ordering options for the TLSFingerprintProfile queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} + +// ByEnableGrease orders the results by the enable_grease field. +func ByEnableGrease(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEnableGrease, opts...).ToFunc() +} diff --git a/backend/ent/tlsfingerprintprofile/where.go b/backend/ent/tlsfingerprintprofile/where.go new file mode 100644 index 00000000..f7d1ba27 --- /dev/null +++ b/backend/ent/tlsfingerprintprofile/where.go @@ -0,0 +1,415 @@ +// Code generated by ent, DO NOT EDIT. + +package tlsfingerprintprofile + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. 
+func ID(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. 
+func UpdatedAt(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldName, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldDescription, v)) +} + +// EnableGrease applies equality check predicate on the "enable_grease" field. It's identical to EnableGreaseEQ. +func EnableGrease(v bool) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldEnableGrease, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. 
+func CreatedAtGT(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. 
+func UpdatedAtGTE(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. 
+func NameLTE(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldContainsFold(FieldName, v)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. 
+func DescriptionNotIn(vs ...string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. +func DescriptionGT(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. 
+func DescriptionIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. +func DescriptionNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldContainsFold(FieldDescription, v)) +} + +// EnableGreaseEQ applies the EQ predicate on the "enable_grease" field. +func EnableGreaseEQ(v bool) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldEQ(FieldEnableGrease, v)) +} + +// EnableGreaseNEQ applies the NEQ predicate on the "enable_grease" field. +func EnableGreaseNEQ(v bool) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNEQ(FieldEnableGrease, v)) +} + +// CipherSuitesIsNil applies the IsNil predicate on the "cipher_suites" field. +func CipherSuitesIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldCipherSuites)) +} + +// CipherSuitesNotNil applies the NotNil predicate on the "cipher_suites" field. +func CipherSuitesNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldCipherSuites)) +} + +// CurvesIsNil applies the IsNil predicate on the "curves" field. 
+func CurvesIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldCurves)) +} + +// CurvesNotNil applies the NotNil predicate on the "curves" field. +func CurvesNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldCurves)) +} + +// PointFormatsIsNil applies the IsNil predicate on the "point_formats" field. +func PointFormatsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldPointFormats)) +} + +// PointFormatsNotNil applies the NotNil predicate on the "point_formats" field. +func PointFormatsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldPointFormats)) +} + +// SignatureAlgorithmsIsNil applies the IsNil predicate on the "signature_algorithms" field. +func SignatureAlgorithmsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldSignatureAlgorithms)) +} + +// SignatureAlgorithmsNotNil applies the NotNil predicate on the "signature_algorithms" field. +func SignatureAlgorithmsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldSignatureAlgorithms)) +} + +// AlpnProtocolsIsNil applies the IsNil predicate on the "alpn_protocols" field. +func AlpnProtocolsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldAlpnProtocols)) +} + +// AlpnProtocolsNotNil applies the NotNil predicate on the "alpn_protocols" field. +func AlpnProtocolsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldAlpnProtocols)) +} + +// SupportedVersionsIsNil applies the IsNil predicate on the "supported_versions" field. 
+func SupportedVersionsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldSupportedVersions)) +} + +// SupportedVersionsNotNil applies the NotNil predicate on the "supported_versions" field. +func SupportedVersionsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldSupportedVersions)) +} + +// KeyShareGroupsIsNil applies the IsNil predicate on the "key_share_groups" field. +func KeyShareGroupsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldKeyShareGroups)) +} + +// KeyShareGroupsNotNil applies the NotNil predicate on the "key_share_groups" field. +func KeyShareGroupsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldKeyShareGroups)) +} + +// PskModesIsNil applies the IsNil predicate on the "psk_modes" field. +func PskModesIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldPskModes)) +} + +// PskModesNotNil applies the NotNil predicate on the "psk_modes" field. +func PskModesNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldPskModes)) +} + +// ExtensionsIsNil applies the IsNil predicate on the "extensions" field. +func ExtensionsIsNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldIsNull(FieldExtensions)) +} + +// ExtensionsNotNil applies the NotNil predicate on the "extensions" field. +func ExtensionsNotNil() predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.FieldNotNull(FieldExtensions)) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.TLSFingerprintProfile) predicate.TLSFingerprintProfile { + return predicate.TLSFingerprintProfile(sql.NotPredicates(p)) +} diff --git a/backend/ent/tlsfingerprintprofile_create.go b/backend/ent/tlsfingerprintprofile_create.go new file mode 100644 index 00000000..70a5e6be --- /dev/null +++ b/backend/ent/tlsfingerprintprofile_create.go @@ -0,0 +1,1341 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" +) + +// TLSFingerprintProfileCreate is the builder for creating a TLSFingerprintProfile entity. +type TLSFingerprintProfileCreate struct { + config + mutation *TLSFingerprintProfileMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *TLSFingerprintProfileCreate) SetCreatedAt(v time.Time) *TLSFingerprintProfileCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *TLSFingerprintProfileCreate) SetNillableCreatedAt(v *time.Time) *TLSFingerprintProfileCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_c *TLSFingerprintProfileCreate) SetUpdatedAt(v time.Time) *TLSFingerprintProfileCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *TLSFingerprintProfileCreate) SetNillableUpdatedAt(v *time.Time) *TLSFingerprintProfileCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *TLSFingerprintProfileCreate) SetName(v string) *TLSFingerprintProfileCreate { + _c.mutation.SetName(v) + return _c +} + +// SetDescription sets the "description" field. +func (_c *TLSFingerprintProfileCreate) SetDescription(v string) *TLSFingerprintProfileCreate { + _c.mutation.SetDescription(v) + return _c +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_c *TLSFingerprintProfileCreate) SetNillableDescription(v *string) *TLSFingerprintProfileCreate { + if v != nil { + _c.SetDescription(*v) + } + return _c +} + +// SetEnableGrease sets the "enable_grease" field. +func (_c *TLSFingerprintProfileCreate) SetEnableGrease(v bool) *TLSFingerprintProfileCreate { + _c.mutation.SetEnableGrease(v) + return _c +} + +// SetNillableEnableGrease sets the "enable_grease" field if the given value is not nil. +func (_c *TLSFingerprintProfileCreate) SetNillableEnableGrease(v *bool) *TLSFingerprintProfileCreate { + if v != nil { + _c.SetEnableGrease(*v) + } + return _c +} + +// SetCipherSuites sets the "cipher_suites" field. +func (_c *TLSFingerprintProfileCreate) SetCipherSuites(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetCipherSuites(v) + return _c +} + +// SetCurves sets the "curves" field. +func (_c *TLSFingerprintProfileCreate) SetCurves(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetCurves(v) + return _c +} + +// SetPointFormats sets the "point_formats" field. 
+func (_c *TLSFingerprintProfileCreate) SetPointFormats(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetPointFormats(v) + return _c +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. +func (_c *TLSFingerprintProfileCreate) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetSignatureAlgorithms(v) + return _c +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (_c *TLSFingerprintProfileCreate) SetAlpnProtocols(v []string) *TLSFingerprintProfileCreate { + _c.mutation.SetAlpnProtocols(v) + return _c +} + +// SetSupportedVersions sets the "supported_versions" field. +func (_c *TLSFingerprintProfileCreate) SetSupportedVersions(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetSupportedVersions(v) + return _c +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (_c *TLSFingerprintProfileCreate) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetKeyShareGroups(v) + return _c +} + +// SetPskModes sets the "psk_modes" field. +func (_c *TLSFingerprintProfileCreate) SetPskModes(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetPskModes(v) + return _c +} + +// SetExtensions sets the "extensions" field. +func (_c *TLSFingerprintProfileCreate) SetExtensions(v []uint16) *TLSFingerprintProfileCreate { + _c.mutation.SetExtensions(v) + return _c +} + +// Mutation returns the TLSFingerprintProfileMutation object of the builder. +func (_c *TLSFingerprintProfileCreate) Mutation() *TLSFingerprintProfileMutation { + return _c.mutation +} + +// Save creates the TLSFingerprintProfile in the database. +func (_c *TLSFingerprintProfileCreate) Save(ctx context.Context) (*TLSFingerprintProfile, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. 
+func (_c *TLSFingerprintProfileCreate) SaveX(ctx context.Context) *TLSFingerprintProfile { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *TLSFingerprintProfileCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *TLSFingerprintProfileCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *TLSFingerprintProfileCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := tlsfingerprintprofile.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := tlsfingerprintprofile.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.EnableGrease(); !ok { + v := tlsfingerprintprofile.DefaultEnableGrease + _c.mutation.SetEnableGrease(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *TLSFingerprintProfileCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "TLSFingerprintProfile.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "TLSFingerprintProfile.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "TLSFingerprintProfile.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := tlsfingerprintprofile.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "TLSFingerprintProfile.name": %w`, err)} + } + } + if _, ok := _c.mutation.EnableGrease(); !ok { + return &ValidationError{Name: "enable_grease", err: errors.New(`ent: missing required field "TLSFingerprintProfile.enable_grease"`)} + } + return nil +} + +func (_c *TLSFingerprintProfileCreate) sqlSave(ctx context.Context) (*TLSFingerprintProfile, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *TLSFingerprintProfileCreate) createSpec() (*TLSFingerprintProfile, *sqlgraph.CreateSpec) { + var ( + _node = &TLSFingerprintProfile{config: _c.config} + _spec = sqlgraph.NewCreateSpec(tlsfingerprintprofile.Table, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCreatedAt, field.TypeTime, value) + 
_node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(tlsfingerprintprofile.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(tlsfingerprintprofile.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Description(); ok { + _spec.SetField(tlsfingerprintprofile.FieldDescription, field.TypeString, value) + _node.Description = &value + } + if value, ok := _c.mutation.EnableGrease(); ok { + _spec.SetField(tlsfingerprintprofile.FieldEnableGrease, field.TypeBool, value) + _node.EnableGrease = value + } + if value, ok := _c.mutation.CipherSuites(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON, value) + _node.CipherSuites = value + } + if value, ok := _c.mutation.Curves(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCurves, field.TypeJSON, value) + _node.Curves = value + } + if value, ok := _c.mutation.PointFormats(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON, value) + _node.PointFormats = value + } + if value, ok := _c.mutation.SignatureAlgorithms(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON, value) + _node.SignatureAlgorithms = value + } + if value, ok := _c.mutation.AlpnProtocols(); ok { + _spec.SetField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON, value) + _node.AlpnProtocols = value + } + if value, ok := _c.mutation.SupportedVersions(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON, value) + _node.SupportedVersions = value + } + if value, ok := _c.mutation.KeyShareGroups(); ok { + _spec.SetField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON, value) + _node.KeyShareGroups = value + } + if value, ok := _c.mutation.PskModes(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON, value) + _node.PskModes = value + } + if 
value, ok := _c.mutation.Extensions(); ok { + _spec.SetField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON, value) + _node.Extensions = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.TLSFingerprintProfile.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.TLSFingerprintProfileUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *TLSFingerprintProfileCreate) OnConflict(opts ...sql.ConflictOption) *TLSFingerprintProfileUpsertOne { + _c.conflict = opts + return &TLSFingerprintProfileUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *TLSFingerprintProfileCreate) OnConflictColumns(columns ...string) *TLSFingerprintProfileUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &TLSFingerprintProfileUpsertOne{ + create: _c, + } +} + +type ( + // TLSFingerprintProfileUpsertOne is the builder for "upsert"-ing + // one TLSFingerprintProfile node. + TLSFingerprintProfileUpsertOne struct { + create *TLSFingerprintProfileCreate + } + + // TLSFingerprintProfileUpsert is the "OnConflict" setter. + TLSFingerprintProfileUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *TLSFingerprintProfileUpsert) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsert) UpdateUpdatedAt() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *TLSFingerprintProfileUpsert) SetName(v string) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateName() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldName) + return u +} + +// SetDescription sets the "description" field. +func (u *TLSFingerprintProfileUpsert) SetDescription(v string) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldDescription, v) + return u +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateDescription() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldDescription) + return u +} + +// ClearDescription clears the value of the "description" field. +func (u *TLSFingerprintProfileUpsert) ClearDescription() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldDescription) + return u +} + +// SetEnableGrease sets the "enable_grease" field. +func (u *TLSFingerprintProfileUpsert) SetEnableGrease(v bool) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldEnableGrease, v) + return u +} + +// UpdateEnableGrease sets the "enable_grease" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateEnableGrease() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldEnableGrease) + return u +} + +// SetCipherSuites sets the "cipher_suites" field. 
+func (u *TLSFingerprintProfileUpsert) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldCipherSuites, v) + return u +} + +// UpdateCipherSuites sets the "cipher_suites" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateCipherSuites() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldCipherSuites) + return u +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. +func (u *TLSFingerprintProfileUpsert) ClearCipherSuites() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldCipherSuites) + return u +} + +// SetCurves sets the "curves" field. +func (u *TLSFingerprintProfileUpsert) SetCurves(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldCurves, v) + return u +} + +// UpdateCurves sets the "curves" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateCurves() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldCurves) + return u +} + +// ClearCurves clears the value of the "curves" field. +func (u *TLSFingerprintProfileUpsert) ClearCurves() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldCurves) + return u +} + +// SetPointFormats sets the "point_formats" field. +func (u *TLSFingerprintProfileUpsert) SetPointFormats(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldPointFormats, v) + return u +} + +// UpdatePointFormats sets the "point_formats" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdatePointFormats() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldPointFormats) + return u +} + +// ClearPointFormats clears the value of the "point_formats" field. 
+func (u *TLSFingerprintProfileUpsert) ClearPointFormats() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldPointFormats) + return u +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsert) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldSignatureAlgorithms, v) + return u +} + +// UpdateSignatureAlgorithms sets the "signature_algorithms" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateSignatureAlgorithms() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldSignatureAlgorithms) + return u +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsert) ClearSignatureAlgorithms() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldSignatureAlgorithms) + return u +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsert) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldAlpnProtocols, v) + return u +} + +// UpdateAlpnProtocols sets the "alpn_protocols" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateAlpnProtocols() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldAlpnProtocols) + return u +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsert) ClearAlpnProtocols() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldAlpnProtocols) + return u +} + +// SetSupportedVersions sets the "supported_versions" field. 
+func (u *TLSFingerprintProfileUpsert) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldSupportedVersions, v) + return u +} + +// UpdateSupportedVersions sets the "supported_versions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateSupportedVersions() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldSupportedVersions) + return u +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. +func (u *TLSFingerprintProfileUpsert) ClearSupportedVersions() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldSupportedVersions) + return u +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (u *TLSFingerprintProfileUpsert) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldKeyShareGroups, v) + return u +} + +// UpdateKeyShareGroups sets the "key_share_groups" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateKeyShareGroups() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldKeyShareGroups) + return u +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. +func (u *TLSFingerprintProfileUpsert) ClearKeyShareGroups() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldKeyShareGroups) + return u +} + +// SetPskModes sets the "psk_modes" field. +func (u *TLSFingerprintProfileUpsert) SetPskModes(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldPskModes, v) + return u +} + +// UpdatePskModes sets the "psk_modes" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdatePskModes() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldPskModes) + return u +} + +// ClearPskModes clears the value of the "psk_modes" field. 
+func (u *TLSFingerprintProfileUpsert) ClearPskModes() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldPskModes) + return u +} + +// SetExtensions sets the "extensions" field. +func (u *TLSFingerprintProfileUpsert) SetExtensions(v []uint16) *TLSFingerprintProfileUpsert { + u.Set(tlsfingerprintprofile.FieldExtensions, v) + return u +} + +// UpdateExtensions sets the "extensions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsert) UpdateExtensions() *TLSFingerprintProfileUpsert { + u.SetExcluded(tlsfingerprintprofile.FieldExtensions) + return u +} + +// ClearExtensions clears the value of the "extensions" field. +func (u *TLSFingerprintProfileUpsert) ClearExtensions() *TLSFingerprintProfileUpsert { + u.SetNull(tlsfingerprintprofile.FieldExtensions) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *TLSFingerprintProfileUpsertOne) UpdateNewValues() *TLSFingerprintProfileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(tlsfingerprintprofile.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *TLSFingerprintProfileUpsertOne) Ignore() *TLSFingerprintProfileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *TLSFingerprintProfileUpsertOne) DoNothing() *TLSFingerprintProfileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the TLSFingerprintProfileCreate.OnConflict +// documentation for more info. +func (u *TLSFingerprintProfileUpsertOne) Update(set func(*TLSFingerprintProfileUpsert)) *TLSFingerprintProfileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&TLSFingerprintProfileUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *TLSFingerprintProfileUpsertOne) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateUpdatedAt() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *TLSFingerprintProfileUpsertOne) SetName(v string) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateName() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateName() + }) +} + +// SetDescription sets the "description" field. +func (u *TLSFingerprintProfileUpsertOne) SetDescription(v string) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsertOne) UpdateDescription() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *TLSFingerprintProfileUpsertOne) ClearDescription() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearDescription() + }) +} + +// SetEnableGrease sets the "enable_grease" field. +func (u *TLSFingerprintProfileUpsertOne) SetEnableGrease(v bool) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetEnableGrease(v) + }) +} + +// UpdateEnableGrease sets the "enable_grease" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateEnableGrease() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateEnableGrease() + }) +} + +// SetCipherSuites sets the "cipher_suites" field. +func (u *TLSFingerprintProfileUpsertOne) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetCipherSuites(v) + }) +} + +// UpdateCipherSuites sets the "cipher_suites" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateCipherSuites() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateCipherSuites() + }) +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. +func (u *TLSFingerprintProfileUpsertOne) ClearCipherSuites() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearCipherSuites() + }) +} + +// SetCurves sets the "curves" field. 
+func (u *TLSFingerprintProfileUpsertOne) SetCurves(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetCurves(v) + }) +} + +// UpdateCurves sets the "curves" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateCurves() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateCurves() + }) +} + +// ClearCurves clears the value of the "curves" field. +func (u *TLSFingerprintProfileUpsertOne) ClearCurves() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearCurves() + }) +} + +// SetPointFormats sets the "point_formats" field. +func (u *TLSFingerprintProfileUpsertOne) SetPointFormats(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetPointFormats(v) + }) +} + +// UpdatePointFormats sets the "point_formats" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdatePointFormats() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdatePointFormats() + }) +} + +// ClearPointFormats clears the value of the "point_formats" field. +func (u *TLSFingerprintProfileUpsertOne) ClearPointFormats() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearPointFormats() + }) +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsertOne) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetSignatureAlgorithms(v) + }) +} + +// UpdateSignatureAlgorithms sets the "signature_algorithms" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsertOne) UpdateSignatureAlgorithms() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateSignatureAlgorithms() + }) +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsertOne) ClearSignatureAlgorithms() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearSignatureAlgorithms() + }) +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsertOne) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetAlpnProtocols(v) + }) +} + +// UpdateAlpnProtocols sets the "alpn_protocols" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateAlpnProtocols() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateAlpnProtocols() + }) +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsertOne) ClearAlpnProtocols() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearAlpnProtocols() + }) +} + +// SetSupportedVersions sets the "supported_versions" field. +func (u *TLSFingerprintProfileUpsertOne) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetSupportedVersions(v) + }) +} + +// UpdateSupportedVersions sets the "supported_versions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateSupportedVersions() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateSupportedVersions() + }) +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. 
+func (u *TLSFingerprintProfileUpsertOne) ClearSupportedVersions() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearSupportedVersions() + }) +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (u *TLSFingerprintProfileUpsertOne) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetKeyShareGroups(v) + }) +} + +// UpdateKeyShareGroups sets the "key_share_groups" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateKeyShareGroups() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateKeyShareGroups() + }) +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. +func (u *TLSFingerprintProfileUpsertOne) ClearKeyShareGroups() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearKeyShareGroups() + }) +} + +// SetPskModes sets the "psk_modes" field. +func (u *TLSFingerprintProfileUpsertOne) SetPskModes(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetPskModes(v) + }) +} + +// UpdatePskModes sets the "psk_modes" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdatePskModes() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdatePskModes() + }) +} + +// ClearPskModes clears the value of the "psk_modes" field. +func (u *TLSFingerprintProfileUpsertOne) ClearPskModes() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearPskModes() + }) +} + +// SetExtensions sets the "extensions" field. 
+func (u *TLSFingerprintProfileUpsertOne) SetExtensions(v []uint16) *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetExtensions(v) + }) +} + +// UpdateExtensions sets the "extensions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertOne) UpdateExtensions() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateExtensions() + }) +} + +// ClearExtensions clears the value of the "extensions" field. +func (u *TLSFingerprintProfileUpsertOne) ClearExtensions() *TLSFingerprintProfileUpsertOne { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearExtensions() + }) +} + +// Exec executes the query. +func (u *TLSFingerprintProfileUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for TLSFingerprintProfileCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *TLSFingerprintProfileUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *TLSFingerprintProfileUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *TLSFingerprintProfileUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// TLSFingerprintProfileCreateBulk is the builder for creating many TLSFingerprintProfile entities in bulk. +type TLSFingerprintProfileCreateBulk struct { + config + err error + builders []*TLSFingerprintProfileCreate + conflict []sql.ConflictOption +} + +// Save creates the TLSFingerprintProfile entities in the database. 
+func (_c *TLSFingerprintProfileCreateBulk) Save(ctx context.Context) ([]*TLSFingerprintProfile, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*TLSFingerprintProfile, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*TLSFingerprintProfileMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *TLSFingerprintProfileCreateBulk) SaveX(ctx context.Context) []*TLSFingerprintProfile { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *TLSFingerprintProfileCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *TLSFingerprintProfileCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.TLSFingerprintProfile.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.TLSFingerprintProfileUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *TLSFingerprintProfileCreateBulk) OnConflict(opts ...sql.ConflictOption) *TLSFingerprintProfileUpsertBulk { + _c.conflict = opts + return &TLSFingerprintProfileUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *TLSFingerprintProfileCreateBulk) OnConflictColumns(columns ...string) *TLSFingerprintProfileUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &TLSFingerprintProfileUpsertBulk{ + create: _c, + } +} + +// TLSFingerprintProfileUpsertBulk is the builder for "upsert"-ing +// a bulk of TLSFingerprintProfile nodes. +type TLSFingerprintProfileUpsertBulk struct { + create *TLSFingerprintProfileCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *TLSFingerprintProfileUpsertBulk) UpdateNewValues() *TLSFingerprintProfileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(tlsfingerprintprofile.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.TLSFingerprintProfile.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *TLSFingerprintProfileUpsertBulk) Ignore() *TLSFingerprintProfileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *TLSFingerprintProfileUpsertBulk) DoNothing() *TLSFingerprintProfileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the TLSFingerprintProfileCreateBulk.OnConflict +// documentation for more info. +func (u *TLSFingerprintProfileUpsertBulk) Update(set func(*TLSFingerprintProfileUpsert)) *TLSFingerprintProfileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&TLSFingerprintProfileUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *TLSFingerprintProfileUpsertBulk) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsertBulk) UpdateUpdatedAt() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *TLSFingerprintProfileUpsertBulk) SetName(v string) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateName() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateName() + }) +} + +// SetDescription sets the "description" field. +func (u *TLSFingerprintProfileUpsertBulk) SetDescription(v string) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateDescription() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearDescription() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearDescription() + }) +} + +// SetEnableGrease sets the "enable_grease" field. +func (u *TLSFingerprintProfileUpsertBulk) SetEnableGrease(v bool) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetEnableGrease(v) + }) +} + +// UpdateEnableGrease sets the "enable_grease" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsertBulk) UpdateEnableGrease() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateEnableGrease() + }) +} + +// SetCipherSuites sets the "cipher_suites" field. +func (u *TLSFingerprintProfileUpsertBulk) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetCipherSuites(v) + }) +} + +// UpdateCipherSuites sets the "cipher_suites" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateCipherSuites() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateCipherSuites() + }) +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearCipherSuites() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearCipherSuites() + }) +} + +// SetCurves sets the "curves" field. +func (u *TLSFingerprintProfileUpsertBulk) SetCurves(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetCurves(v) + }) +} + +// UpdateCurves sets the "curves" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateCurves() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateCurves() + }) +} + +// ClearCurves clears the value of the "curves" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearCurves() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearCurves() + }) +} + +// SetPointFormats sets the "point_formats" field. 
+func (u *TLSFingerprintProfileUpsertBulk) SetPointFormats(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetPointFormats(v) + }) +} + +// UpdatePointFormats sets the "point_formats" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdatePointFormats() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdatePointFormats() + }) +} + +// ClearPointFormats clears the value of the "point_formats" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearPointFormats() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearPointFormats() + }) +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsertBulk) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetSignatureAlgorithms(v) + }) +} + +// UpdateSignatureAlgorithms sets the "signature_algorithms" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateSignatureAlgorithms() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateSignatureAlgorithms() + }) +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearSignatureAlgorithms() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearSignatureAlgorithms() + }) +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsertBulk) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetAlpnProtocols(v) + }) +} + +// UpdateAlpnProtocols sets the "alpn_protocols" field to the value that was provided on create. 
+func (u *TLSFingerprintProfileUpsertBulk) UpdateAlpnProtocols() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateAlpnProtocols() + }) +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearAlpnProtocols() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearAlpnProtocols() + }) +} + +// SetSupportedVersions sets the "supported_versions" field. +func (u *TLSFingerprintProfileUpsertBulk) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetSupportedVersions(v) + }) +} + +// UpdateSupportedVersions sets the "supported_versions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateSupportedVersions() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateSupportedVersions() + }) +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearSupportedVersions() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearSupportedVersions() + }) +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (u *TLSFingerprintProfileUpsertBulk) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetKeyShareGroups(v) + }) +} + +// UpdateKeyShareGroups sets the "key_share_groups" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateKeyShareGroups() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateKeyShareGroups() + }) +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. 
+func (u *TLSFingerprintProfileUpsertBulk) ClearKeyShareGroups() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearKeyShareGroups() + }) +} + +// SetPskModes sets the "psk_modes" field. +func (u *TLSFingerprintProfileUpsertBulk) SetPskModes(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetPskModes(v) + }) +} + +// UpdatePskModes sets the "psk_modes" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdatePskModes() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdatePskModes() + }) +} + +// ClearPskModes clears the value of the "psk_modes" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearPskModes() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearPskModes() + }) +} + +// SetExtensions sets the "extensions" field. +func (u *TLSFingerprintProfileUpsertBulk) SetExtensions(v []uint16) *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.SetExtensions(v) + }) +} + +// UpdateExtensions sets the "extensions" field to the value that was provided on create. +func (u *TLSFingerprintProfileUpsertBulk) UpdateExtensions() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.UpdateExtensions() + }) +} + +// ClearExtensions clears the value of the "extensions" field. +func (u *TLSFingerprintProfileUpsertBulk) ClearExtensions() *TLSFingerprintProfileUpsertBulk { + return u.Update(func(s *TLSFingerprintProfileUpsert) { + s.ClearExtensions() + }) +} + +// Exec executes the query. +func (u *TLSFingerprintProfileUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the TLSFingerprintProfileCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for TLSFingerprintProfileCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *TLSFingerprintProfileUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/tlsfingerprintprofile_delete.go b/backend/ent/tlsfingerprintprofile_delete.go new file mode 100644 index 00000000..2f6dea2e --- /dev/null +++ b/backend/ent/tlsfingerprintprofile_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" +) + +// TLSFingerprintProfileDelete is the builder for deleting a TLSFingerprintProfile entity. +type TLSFingerprintProfileDelete struct { + config + hooks []Hook + mutation *TLSFingerprintProfileMutation +} + +// Where appends a list predicates to the TLSFingerprintProfileDelete builder. +func (_d *TLSFingerprintProfileDelete) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *TLSFingerprintProfileDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *TLSFingerprintProfileDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *TLSFingerprintProfileDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(tlsfingerprintprofile.Table, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// TLSFingerprintProfileDeleteOne is the builder for deleting a single TLSFingerprintProfile entity. +type TLSFingerprintProfileDeleteOne struct { + _d *TLSFingerprintProfileDelete +} + +// Where appends a list predicates to the TLSFingerprintProfileDelete builder. +func (_d *TLSFingerprintProfileDeleteOne) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *TLSFingerprintProfileDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{tlsfingerprintprofile.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *TLSFingerprintProfileDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/tlsfingerprintprofile_query.go b/backend/ent/tlsfingerprintprofile_query.go new file mode 100644 index 00000000..d1ef4f1d --- /dev/null +++ b/backend/ent/tlsfingerprintprofile_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" +) + +// TLSFingerprintProfileQuery is the builder for querying TLSFingerprintProfile entities. +type TLSFingerprintProfileQuery struct { + config + ctx *QueryContext + order []tlsfingerprintprofile.OrderOption + inters []Interceptor + predicates []predicate.TLSFingerprintProfile + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the TLSFingerprintProfileQuery builder. +func (_q *TLSFingerprintProfileQuery) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *TLSFingerprintProfileQuery) Limit(limit int) *TLSFingerprintProfileQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *TLSFingerprintProfileQuery) Offset(offset int) *TLSFingerprintProfileQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *TLSFingerprintProfileQuery) Unique(unique bool) *TLSFingerprintProfileQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *TLSFingerprintProfileQuery) Order(o ...tlsfingerprintprofile.OrderOption) *TLSFingerprintProfileQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first TLSFingerprintProfile entity from the query. +// Returns a *NotFoundError when no TLSFingerprintProfile was found. 
+func (_q *TLSFingerprintProfileQuery) First(ctx context.Context) (*TLSFingerprintProfile, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{tlsfingerprintprofile.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) FirstX(ctx context.Context) *TLSFingerprintProfile { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first TLSFingerprintProfile ID from the query. +// Returns a *NotFoundError when no TLSFingerprintProfile ID was found. +func (_q *TLSFingerprintProfileQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{tlsfingerprintprofile.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single TLSFingerprintProfile entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one TLSFingerprintProfile entity is found. +// Returns a *NotFoundError when no TLSFingerprintProfile entities are found. 
+func (_q *TLSFingerprintProfileQuery) Only(ctx context.Context) (*TLSFingerprintProfile, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{tlsfingerprintprofile.Label} + default: + return nil, &NotSingularError{tlsfingerprintprofile.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) OnlyX(ctx context.Context) *TLSFingerprintProfile { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only TLSFingerprintProfile ID in the query. +// Returns a *NotSingularError when more than one TLSFingerprintProfile ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *TLSFingerprintProfileQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{tlsfingerprintprofile.Label} + default: + err = &NotSingularError{tlsfingerprintprofile.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of TLSFingerprintProfiles. +func (_q *TLSFingerprintProfileQuery) All(ctx context.Context) ([]*TLSFingerprintProfile, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*TLSFingerprintProfile, *TLSFingerprintProfileQuery]() + return withInterceptors[[]*TLSFingerprintProfile](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. 
+func (_q *TLSFingerprintProfileQuery) AllX(ctx context.Context) []*TLSFingerprintProfile { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of TLSFingerprintProfile IDs. +func (_q *TLSFingerprintProfileQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(tlsfingerprintprofile.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *TLSFingerprintProfileQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*TLSFingerprintProfileQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *TLSFingerprintProfileQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *TLSFingerprintProfileQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. 
+func (_q *TLSFingerprintProfileQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the TLSFingerprintProfileQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *TLSFingerprintProfileQuery) Clone() *TLSFingerprintProfileQuery { + if _q == nil { + return nil + } + return &TLSFingerprintProfileQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]tlsfingerprintprofile.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.TLSFingerprintProfile{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.TLSFingerprintProfile.Query(). +// GroupBy(tlsfingerprintprofile.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *TLSFingerprintProfileQuery) GroupBy(field string, fields ...string) *TLSFingerprintProfileGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &TLSFingerprintProfileGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = tlsfingerprintprofile.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.TLSFingerprintProfile.Query(). +// Select(tlsfingerprintprofile.FieldCreatedAt). 
+// Scan(ctx, &v) +func (_q *TLSFingerprintProfileQuery) Select(fields ...string) *TLSFingerprintProfileSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &TLSFingerprintProfileSelect{TLSFingerprintProfileQuery: _q} + sbuild.label = tlsfingerprintprofile.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a TLSFingerprintProfileSelect configured with the given aggregations. +func (_q *TLSFingerprintProfileQuery) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *TLSFingerprintProfileQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !tlsfingerprintprofile.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *TLSFingerprintProfileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*TLSFingerprintProfile, error) { + var ( + nodes = []*TLSFingerprintProfile{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*TLSFingerprintProfile).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &TLSFingerprintProfile{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return 
nodes, nil + } + return nodes, nil +} + +func (_q *TLSFingerprintProfileQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *TLSFingerprintProfileQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, tlsfingerprintprofile.FieldID) + for i := range fields { + if fields[i] != tlsfingerprintprofile.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *TLSFingerprintProfileQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(tlsfingerprintprofile.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = tlsfingerprintprofile.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) 
+ } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *TLSFingerprintProfileQuery) ForUpdate(opts ...sql.LockOption) *TLSFingerprintProfileQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *TLSFingerprintProfileQuery) ForShare(opts ...sql.LockOption) *TLSFingerprintProfileQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// TLSFingerprintProfileGroupBy is the group-by builder for TLSFingerprintProfile entities. +type TLSFingerprintProfileGroupBy struct { + selector + build *TLSFingerprintProfileQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *TLSFingerprintProfileGroupBy) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileGroupBy { + _g.fns = append(_g.fns, fns...) 
+ return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *TLSFingerprintProfileGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*TLSFingerprintProfileQuery, *TLSFingerprintProfileGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *TLSFingerprintProfileGroupBy) sqlScan(ctx context.Context, root *TLSFingerprintProfileQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// TLSFingerprintProfileSelect is the builder for selecting fields of TLSFingerprintProfile entities. +type TLSFingerprintProfileSelect struct { + *TLSFingerprintProfileQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *TLSFingerprintProfileSelect) Aggregate(fns ...AggregateFunc) *TLSFingerprintProfileSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *TLSFingerprintProfileSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*TLSFingerprintProfileQuery, *TLSFingerprintProfileSelect](ctx, _s.TLSFingerprintProfileQuery, _s, _s.inters, v) +} + +func (_s *TLSFingerprintProfileSelect) sqlScan(ctx context.Context, root *TLSFingerprintProfileQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/tlsfingerprintprofile_update.go b/backend/ent/tlsfingerprintprofile_update.go new file mode 100644 index 00000000..3b12508c --- /dev/null +++ b/backend/ent/tlsfingerprintprofile_update.go @@ -0,0 +1,881 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" +) + +// TLSFingerprintProfileUpdate is the builder for updating TLSFingerprintProfile entities. +type TLSFingerprintProfileUpdate struct { + config + hooks []Hook + mutation *TLSFingerprintProfileMutation +} + +// Where appends a list predicates to the TLSFingerprintProfileUpdate builder. 
+func (_u *TLSFingerprintProfileUpdate) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *TLSFingerprintProfileUpdate) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *TLSFingerprintProfileUpdate) SetName(v string) *TLSFingerprintProfileUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdate) SetNillableName(v *string) *TLSFingerprintProfileUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetDescription sets the "description" field. +func (_u *TLSFingerprintProfileUpdate) SetDescription(v string) *TLSFingerprintProfileUpdate { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdate) SetNillableDescription(v *string) *TLSFingerprintProfileUpdate { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *TLSFingerprintProfileUpdate) ClearDescription() *TLSFingerprintProfileUpdate { + _u.mutation.ClearDescription() + return _u +} + +// SetEnableGrease sets the "enable_grease" field. +func (_u *TLSFingerprintProfileUpdate) SetEnableGrease(v bool) *TLSFingerprintProfileUpdate { + _u.mutation.SetEnableGrease(v) + return _u +} + +// SetNillableEnableGrease sets the "enable_grease" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdate) SetNillableEnableGrease(v *bool) *TLSFingerprintProfileUpdate { + if v != nil { + _u.SetEnableGrease(*v) + } + return _u +} + +// SetCipherSuites sets the "cipher_suites" field. 
+func (_u *TLSFingerprintProfileUpdate) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetCipherSuites(v) + return _u +} + +// AppendCipherSuites appends value to the "cipher_suites" field. +func (_u *TLSFingerprintProfileUpdate) AppendCipherSuites(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendCipherSuites(v) + return _u +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. +func (_u *TLSFingerprintProfileUpdate) ClearCipherSuites() *TLSFingerprintProfileUpdate { + _u.mutation.ClearCipherSuites() + return _u +} + +// SetCurves sets the "curves" field. +func (_u *TLSFingerprintProfileUpdate) SetCurves(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetCurves(v) + return _u +} + +// AppendCurves appends value to the "curves" field. +func (_u *TLSFingerprintProfileUpdate) AppendCurves(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendCurves(v) + return _u +} + +// ClearCurves clears the value of the "curves" field. +func (_u *TLSFingerprintProfileUpdate) ClearCurves() *TLSFingerprintProfileUpdate { + _u.mutation.ClearCurves() + return _u +} + +// SetPointFormats sets the "point_formats" field. +func (_u *TLSFingerprintProfileUpdate) SetPointFormats(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetPointFormats(v) + return _u +} + +// AppendPointFormats appends value to the "point_formats" field. +func (_u *TLSFingerprintProfileUpdate) AppendPointFormats(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendPointFormats(v) + return _u +} + +// ClearPointFormats clears the value of the "point_formats" field. +func (_u *TLSFingerprintProfileUpdate) ClearPointFormats() *TLSFingerprintProfileUpdate { + _u.mutation.ClearPointFormats() + return _u +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. 
+func (_u *TLSFingerprintProfileUpdate) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetSignatureAlgorithms(v) + return _u +} + +// AppendSignatureAlgorithms appends value to the "signature_algorithms" field. +func (_u *TLSFingerprintProfileUpdate) AppendSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendSignatureAlgorithms(v) + return _u +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. +func (_u *TLSFingerprintProfileUpdate) ClearSignatureAlgorithms() *TLSFingerprintProfileUpdate { + _u.mutation.ClearSignatureAlgorithms() + return _u +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdate) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpdate { + _u.mutation.SetAlpnProtocols(v) + return _u +} + +// AppendAlpnProtocols appends value to the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdate) AppendAlpnProtocols(v []string) *TLSFingerprintProfileUpdate { + _u.mutation.AppendAlpnProtocols(v) + return _u +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdate) ClearAlpnProtocols() *TLSFingerprintProfileUpdate { + _u.mutation.ClearAlpnProtocols() + return _u +} + +// SetSupportedVersions sets the "supported_versions" field. +func (_u *TLSFingerprintProfileUpdate) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetSupportedVersions(v) + return _u +} + +// AppendSupportedVersions appends value to the "supported_versions" field. +func (_u *TLSFingerprintProfileUpdate) AppendSupportedVersions(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendSupportedVersions(v) + return _u +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. 
+func (_u *TLSFingerprintProfileUpdate) ClearSupportedVersions() *TLSFingerprintProfileUpdate { + _u.mutation.ClearSupportedVersions() + return _u +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdate) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetKeyShareGroups(v) + return _u +} + +// AppendKeyShareGroups appends value to the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdate) AppendKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendKeyShareGroups(v) + return _u +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdate) ClearKeyShareGroups() *TLSFingerprintProfileUpdate { + _u.mutation.ClearKeyShareGroups() + return _u +} + +// SetPskModes sets the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdate) SetPskModes(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetPskModes(v) + return _u +} + +// AppendPskModes appends value to the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdate) AppendPskModes(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendPskModes(v) + return _u +} + +// ClearPskModes clears the value of the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdate) ClearPskModes() *TLSFingerprintProfileUpdate { + _u.mutation.ClearPskModes() + return _u +} + +// SetExtensions sets the "extensions" field. +func (_u *TLSFingerprintProfileUpdate) SetExtensions(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.SetExtensions(v) + return _u +} + +// AppendExtensions appends value to the "extensions" field. +func (_u *TLSFingerprintProfileUpdate) AppendExtensions(v []uint16) *TLSFingerprintProfileUpdate { + _u.mutation.AppendExtensions(v) + return _u +} + +// ClearExtensions clears the value of the "extensions" field. 
+func (_u *TLSFingerprintProfileUpdate) ClearExtensions() *TLSFingerprintProfileUpdate { + _u.mutation.ClearExtensions() + return _u +} + +// Mutation returns the TLSFingerprintProfileMutation object of the builder. +func (_u *TLSFingerprintProfileUpdate) Mutation() *TLSFingerprintProfileMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *TLSFingerprintProfileUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *TLSFingerprintProfileUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *TLSFingerprintProfileUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *TLSFingerprintProfileUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *TLSFingerprintProfileUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := tlsfingerprintprofile.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *TLSFingerprintProfileUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := tlsfingerprintprofile.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "TLSFingerprintProfile.name": %w`, err)} + } + } + return nil +} + +func (_u *TLSFingerprintProfileUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(tlsfingerprintprofile.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(tlsfingerprintprofile.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(tlsfingerprintprofile.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldDescription, field.TypeString) + } + if value, ok := _u.mutation.EnableGrease(); ok { + _spec.SetField(tlsfingerprintprofile.FieldEnableGrease, field.TypeBool, value) + } + if value, ok := _u.mutation.CipherSuites(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedCipherSuites(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldCipherSuites, value) + }) + } + if _u.mutation.CipherSuitesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON) + } + if value, ok := _u.mutation.Curves(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCurves, field.TypeJSON, value) + } 
+ if value, ok := _u.mutation.AppendedCurves(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldCurves, value) + }) + } + if _u.mutation.CurvesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldCurves, field.TypeJSON) + } + if value, ok := _u.mutation.PointFormats(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPointFormats(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldPointFormats, value) + }) + } + if _u.mutation.PointFormatsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON) + } + if value, ok := _u.mutation.SignatureAlgorithms(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSignatureAlgorithms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldSignatureAlgorithms, value) + }) + } + if _u.mutation.SignatureAlgorithmsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON) + } + if value, ok := _u.mutation.AlpnProtocols(); ok { + _spec.SetField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedAlpnProtocols(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldAlpnProtocols, value) + }) + } + if _u.mutation.AlpnProtocolsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON) + } + if value, ok := _u.mutation.SupportedVersions(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedVersions(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldSupportedVersions, value) + }) + } + 
if _u.mutation.SupportedVersionsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON) + } + if value, ok := _u.mutation.KeyShareGroups(); ok { + _spec.SetField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeyShareGroups(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldKeyShareGroups, value) + }) + } + if _u.mutation.KeyShareGroupsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON) + } + if value, ok := _u.mutation.PskModes(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPskModes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldPskModes, value) + }) + } + if _u.mutation.PskModesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON) + } + if value, ok := _u.mutation.Extensions(); ok { + _spec.SetField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedExtensions(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldExtensions, value) + }) + } + if _u.mutation.ExtensionsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{tlsfingerprintprofile.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// TLSFingerprintProfileUpdateOne is the builder for updating a single TLSFingerprintProfile entity. 
+type TLSFingerprintProfileUpdateOne struct { + config + fields []string + hooks []Hook + mutation *TLSFingerprintProfileMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *TLSFingerprintProfileUpdateOne) SetUpdatedAt(v time.Time) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *TLSFingerprintProfileUpdateOne) SetName(v string) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdateOne) SetNillableName(v *string) *TLSFingerprintProfileUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetDescription sets the "description" field. +func (_u *TLSFingerprintProfileUpdateOne) SetDescription(v string) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdateOne) SetNillableDescription(v *string) *TLSFingerprintProfileUpdateOne { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearDescription() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearDescription() + return _u +} + +// SetEnableGrease sets the "enable_grease" field. +func (_u *TLSFingerprintProfileUpdateOne) SetEnableGrease(v bool) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetEnableGrease(v) + return _u +} + +// SetNillableEnableGrease sets the "enable_grease" field if the given value is not nil. +func (_u *TLSFingerprintProfileUpdateOne) SetNillableEnableGrease(v *bool) *TLSFingerprintProfileUpdateOne { + if v != nil { + _u.SetEnableGrease(*v) + } + return _u +} + +// SetCipherSuites sets the "cipher_suites" field. 
+func (_u *TLSFingerprintProfileUpdateOne) SetCipherSuites(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetCipherSuites(v) + return _u +} + +// AppendCipherSuites appends value to the "cipher_suites" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendCipherSuites(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendCipherSuites(v) + return _u +} + +// ClearCipherSuites clears the value of the "cipher_suites" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearCipherSuites() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearCipherSuites() + return _u +} + +// SetCurves sets the "curves" field. +func (_u *TLSFingerprintProfileUpdateOne) SetCurves(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetCurves(v) + return _u +} + +// AppendCurves appends value to the "curves" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendCurves(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendCurves(v) + return _u +} + +// ClearCurves clears the value of the "curves" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearCurves() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearCurves() + return _u +} + +// SetPointFormats sets the "point_formats" field. +func (_u *TLSFingerprintProfileUpdateOne) SetPointFormats(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetPointFormats(v) + return _u +} + +// AppendPointFormats appends value to the "point_formats" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendPointFormats(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendPointFormats(v) + return _u +} + +// ClearPointFormats clears the value of the "point_formats" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearPointFormats() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearPointFormats() + return _u +} + +// SetSignatureAlgorithms sets the "signature_algorithms" field. 
+func (_u *TLSFingerprintProfileUpdateOne) SetSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetSignatureAlgorithms(v) + return _u +} + +// AppendSignatureAlgorithms appends value to the "signature_algorithms" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendSignatureAlgorithms(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendSignatureAlgorithms(v) + return _u +} + +// ClearSignatureAlgorithms clears the value of the "signature_algorithms" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearSignatureAlgorithms() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearSignatureAlgorithms() + return _u +} + +// SetAlpnProtocols sets the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdateOne) SetAlpnProtocols(v []string) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetAlpnProtocols(v) + return _u +} + +// AppendAlpnProtocols appends value to the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendAlpnProtocols(v []string) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendAlpnProtocols(v) + return _u +} + +// ClearAlpnProtocols clears the value of the "alpn_protocols" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearAlpnProtocols() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearAlpnProtocols() + return _u +} + +// SetSupportedVersions sets the "supported_versions" field. +func (_u *TLSFingerprintProfileUpdateOne) SetSupportedVersions(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetSupportedVersions(v) + return _u +} + +// AppendSupportedVersions appends value to the "supported_versions" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendSupportedVersions(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendSupportedVersions(v) + return _u +} + +// ClearSupportedVersions clears the value of the "supported_versions" field. 
+func (_u *TLSFingerprintProfileUpdateOne) ClearSupportedVersions() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearSupportedVersions() + return _u +} + +// SetKeyShareGroups sets the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdateOne) SetKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetKeyShareGroups(v) + return _u +} + +// AppendKeyShareGroups appends value to the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendKeyShareGroups(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendKeyShareGroups(v) + return _u +} + +// ClearKeyShareGroups clears the value of the "key_share_groups" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearKeyShareGroups() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearKeyShareGroups() + return _u +} + +// SetPskModes sets the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdateOne) SetPskModes(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetPskModes(v) + return _u +} + +// AppendPskModes appends value to the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendPskModes(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendPskModes(v) + return _u +} + +// ClearPskModes clears the value of the "psk_modes" field. +func (_u *TLSFingerprintProfileUpdateOne) ClearPskModes() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearPskModes() + return _u +} + +// SetExtensions sets the "extensions" field. +func (_u *TLSFingerprintProfileUpdateOne) SetExtensions(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.SetExtensions(v) + return _u +} + +// AppendExtensions appends value to the "extensions" field. +func (_u *TLSFingerprintProfileUpdateOne) AppendExtensions(v []uint16) *TLSFingerprintProfileUpdateOne { + _u.mutation.AppendExtensions(v) + return _u +} + +// ClearExtensions clears the value of the "extensions" field. 
+func (_u *TLSFingerprintProfileUpdateOne) ClearExtensions() *TLSFingerprintProfileUpdateOne { + _u.mutation.ClearExtensions() + return _u +} + +// Mutation returns the TLSFingerprintProfileMutation object of the builder. +func (_u *TLSFingerprintProfileUpdateOne) Mutation() *TLSFingerprintProfileMutation { + return _u.mutation +} + +// Where appends a list predicates to the TLSFingerprintProfileUpdate builder. +func (_u *TLSFingerprintProfileUpdateOne) Where(ps ...predicate.TLSFingerprintProfile) *TLSFingerprintProfileUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *TLSFingerprintProfileUpdateOne) Select(field string, fields ...string) *TLSFingerprintProfileUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated TLSFingerprintProfile entity. +func (_u *TLSFingerprintProfileUpdateOne) Save(ctx context.Context) (*TLSFingerprintProfile, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *TLSFingerprintProfileUpdateOne) SaveX(ctx context.Context) *TLSFingerprintProfile { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *TLSFingerprintProfileUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *TLSFingerprintProfileUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *TLSFingerprintProfileUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := tlsfingerprintprofile.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *TLSFingerprintProfileUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := tlsfingerprintprofile.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "TLSFingerprintProfile.name": %w`, err)} + } + } + return nil +} + +func (_u *TLSFingerprintProfileUpdateOne) sqlSave(ctx context.Context) (_node *TLSFingerprintProfile, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(tlsfingerprintprofile.Table, tlsfingerprintprofile.Columns, sqlgraph.NewFieldSpec(tlsfingerprintprofile.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "TLSFingerprintProfile.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, tlsfingerprintprofile.FieldID) + for _, f := range fields { + if !tlsfingerprintprofile.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != tlsfingerprintprofile.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(tlsfingerprintprofile.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(tlsfingerprintprofile.FieldName, field.TypeString, value) + } + if value, ok := 
_u.mutation.Description(); ok { + _spec.SetField(tlsfingerprintprofile.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldDescription, field.TypeString) + } + if value, ok := _u.mutation.EnableGrease(); ok { + _spec.SetField(tlsfingerprintprofile.FieldEnableGrease, field.TypeBool, value) + } + if value, ok := _u.mutation.CipherSuites(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedCipherSuites(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldCipherSuites, value) + }) + } + if _u.mutation.CipherSuitesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldCipherSuites, field.TypeJSON) + } + if value, ok := _u.mutation.Curves(); ok { + _spec.SetField(tlsfingerprintprofile.FieldCurves, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedCurves(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldCurves, value) + }) + } + if _u.mutation.CurvesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldCurves, field.TypeJSON) + } + if value, ok := _u.mutation.PointFormats(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPointFormats(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldPointFormats, value) + }) + } + if _u.mutation.PointFormatsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldPointFormats, field.TypeJSON) + } + if value, ok := _u.mutation.SignatureAlgorithms(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSignatureAlgorithms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldSignatureAlgorithms, 
value) + }) + } + if _u.mutation.SignatureAlgorithmsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldSignatureAlgorithms, field.TypeJSON) + } + if value, ok := _u.mutation.AlpnProtocols(); ok { + _spec.SetField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedAlpnProtocols(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldAlpnProtocols, value) + }) + } + if _u.mutation.AlpnProtocolsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldAlpnProtocols, field.TypeJSON) + } + if value, ok := _u.mutation.SupportedVersions(); ok { + _spec.SetField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedVersions(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldSupportedVersions, value) + }) + } + if _u.mutation.SupportedVersionsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldSupportedVersions, field.TypeJSON) + } + if value, ok := _u.mutation.KeyShareGroups(); ok { + _spec.SetField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeyShareGroups(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldKeyShareGroups, value) + }) + } + if _u.mutation.KeyShareGroupsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldKeyShareGroups, field.TypeJSON) + } + if value, ok := _u.mutation.PskModes(); ok { + _spec.SetField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPskModes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldPskModes, value) + }) + } + if _u.mutation.PskModesCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldPskModes, field.TypeJSON) + } + if value, ok := _u.mutation.Extensions(); ok { + 
_spec.SetField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedExtensions(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, tlsfingerprintprofile.FieldExtensions, value) + }) + } + if _u.mutation.ExtensionsCleared() { + _spec.ClearField(tlsfingerprintprofile.FieldExtensions, field.TypeJSON) + } + _node = &TLSFingerprintProfile{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{tlsfingerprintprofile.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index cd3b2296..b5aea447 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -42,6 +42,8 @@ type Tx struct { SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders. + TLSFingerprintProfile *TLSFingerprintProfileClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. UsageCleanupTask *UsageCleanupTaskClient // UsageLog is the client for interacting with the UsageLog builders. 
@@ -201,6 +203,7 @@ func (tx *Tx) init() { tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.Setting = NewSettingClient(tx.config) + tx.TLSFingerprintProfile = NewTLSFingerprintProfileClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) tx.User = NewUserClient(tx.config) diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index f6968d0d..fb4ee1c5 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -32,6 +32,10 @@ type UsageLog struct { RequestID string `json:"request_id,omitempty"` // Model holds the value of the "model" field. Model string `json:"model,omitempty"` + // RequestedModel holds the value of the "requested_model" field. + RequestedModel *string `json:"requested_model,omitempty"` + // UpstreamModel holds the value of the "upstream_model" field. + UpstreamModel *string `json:"upstream_model,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. 
@@ -175,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -230,6 +234,20 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Model = value.String } + case usagelog.FieldRequestedModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field requested_model", values[i]) + } else if value.Valid { + _m.RequestedModel = new(string) + *_m.RequestedModel = value.String + } + case usagelog.FieldUpstreamModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) + } else if value.Valid { + _m.UpstreamModel = new(string) + *_m.UpstreamModel = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -477,6 +495,16 @@ func (_m *UsageLog) String() string { builder.WriteString("model=") builder.WriteString(_m.Model) 
builder.WriteString(", ") + if v := _m.RequestedModel; v != nil { + builder.WriteString("requested_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.UpstreamModel; v != nil { + builder.WriteString("upstream_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index ba97b843..b534f193 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -24,6 +24,10 @@ const ( FieldRequestID = "request_id" // FieldModel holds the string denoting the model field in the database. FieldModel = "model" + // FieldRequestedModel holds the string denoting the requested_model field in the database. + FieldRequestedModel = "requested_model" + // FieldUpstreamModel holds the string denoting the upstream_model field in the database. + FieldUpstreamModel = "upstream_model" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -135,6 +139,8 @@ var Columns = []string{ FieldAccountID, FieldRequestID, FieldModel, + FieldRequestedModel, + FieldUpstreamModel, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -179,6 +185,10 @@ var ( RequestIDValidator func(string) error // ModelValidator is a validator for the "model" field. It is called by the builders before save. ModelValidator func(string) error + // RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save. + RequestedModelValidator func(string) error + // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. 
+ UpstreamModelValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -258,6 +268,16 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModel, opts...).ToFunc() } +// ByRequestedModel orders the results by the requested_model field. +func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestedModel, opts...).ToFunc() +} + +// ByUpstreamModel orders the results by the upstream_model field. +func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index af960335..f95bceb7 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -80,6 +80,16 @@ func Model(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) } +// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ. +func RequestedModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v)) +} + +// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. +func UpstreamModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. 
func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -405,6 +415,156 @@ func ModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) } +// RequestedModelEQ applies the EQ predicate on the "requested_model" field. +func RequestedModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v)) +} + +// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field. +func RequestedModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v)) +} + +// RequestedModelIn applies the In predicate on the "requested_model" field. +func RequestedModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...)) +} + +// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field. +func RequestedModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...)) +} + +// RequestedModelGT applies the GT predicate on the "requested_model" field. +func RequestedModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v)) +} + +// RequestedModelGTE applies the GTE predicate on the "requested_model" field. +func RequestedModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v)) +} + +// RequestedModelLT applies the LT predicate on the "requested_model" field. +func RequestedModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v)) +} + +// RequestedModelLTE applies the LTE predicate on the "requested_model" field. +func RequestedModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v)) +} + +// RequestedModelContains applies the Contains predicate on the "requested_model" field. 
+func RequestedModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v)) +} + +// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field. +func RequestedModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v)) +} + +// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field. +func RequestedModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v)) +} + +// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field. +func RequestedModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel)) +} + +// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field. +func RequestedModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel)) +} + +// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field. +func RequestedModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v)) +} + +// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field. +func RequestedModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v)) +} + +// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. +func UpstreamModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field. +func UpstreamModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelIn applies the In predicate on the "upstream_model" field. 
+func UpstreamModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field. +func UpstreamModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelGT applies the GT predicate on the "upstream_model" field. +func UpstreamModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v)) +} + +// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field. +func UpstreamModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v)) +} + +// UpstreamModelLT applies the LT predicate on the "upstream_model" field. +func UpstreamModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v)) +} + +// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field. +func UpstreamModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v)) +} + +// UpstreamModelContains applies the Contains predicate on the "upstream_model" field. +func UpstreamModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v)) +} + +// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field. +func UpstreamModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v)) +} + +// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field. +func UpstreamModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v)) +} + +// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field. 
+func UpstreamModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel)) +} + +// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field. +func UpstreamModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel)) +} + +// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field. +func UpstreamModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v)) +} + +// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field. +func UpstreamModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index e0285a5e..6ae0bf59 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -57,6 +57,34 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { return _c } +// SetRequestedModel sets the "requested_model" field. +func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate { + _c.mutation.SetRequestedModel(v) + return _c +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetRequestedModel(*v) + } + return _c +} + +// SetUpstreamModel sets the "upstream_model" field. +func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { + _c.mutation.SetUpstreamModel(v) + return _c +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. 
+func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetUpstreamModel(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -596,6 +624,16 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _c.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } + if v, ok := _c.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -714,6 +752,14 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldModel, field.TypeString, value) _node.Model = value } + if value, ok := _c.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + _node.RequestedModel = &value + } + if value, ok := _c.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + _node.UpstreamModel = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -1011,6 +1057,42 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { return u } +// SetRequestedModel sets the "requested_model" field. 
+func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldRequestedModel, v) + return u +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRequestedModel) + return u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldRequestedModel) + return u +} + +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldUpstreamModel, v) + return u +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUpstreamModel) + return u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldUpstreamModel) + return u +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1600,6 +1682,48 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { }) } +// SetRequestedModel sets the "requested_model" field. +func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestedModel(v) + }) +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestedModel() + }) +} + +// ClearRequestedModel clears the value of the "requested_model" field. 
+func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearRequestedModel() + }) +} + +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2434,6 +2558,48 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { }) } +// SetRequestedModel sets the "requested_model" field. +func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestedModel(v) + }) +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestedModel() + }) +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearRequestedModel() + }) +} + +// SetUpstreamModel sets the "upstream_model" field. 
+func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index b46e5b56..516407b9 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -102,6 +102,46 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { return _u } +// SetRequestedModel sets the "requested_model" field. +func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate { + _u.mutation.SetRequestedModel(v) + return _u +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetRequestedModel(*v) + } + return _u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate { + _u.mutation.ClearRequestedModel() + return _u +} + +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. 
+func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -745,6 +785,16 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -795,6 +845,18 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + } + if _u.mutation.RequestedModelCleared() { + _spec.ClearField(usagelog.FieldRequestedModel, field.TypeString) + } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() 
{ + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -1177,6 +1239,46 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { return _u } +// SetRequestedModel sets the "requested_model" field. +func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne { + _u.mutation.SetRequestedModel(v) + return _u +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetRequestedModel(*v) + } + return _u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne { + _u.mutation.ClearRequestedModel() + return _u +} + +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. 
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1833,6 +1935,16 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -1900,6 +2012,18 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + } + if _u.mutation.RequestedModelCleared() { + _spec.ClearField(usagelog.FieldRequestedModel, field.TypeString) + } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } diff --git a/backend/go.sum b/backend/go.sum index 324fe652..f5b7968f 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 
h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= -github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= -github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= @@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= -github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= -github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= @@ -203,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= 
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index e90e56af..3ee5d6cd 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -656,17 +656,33 @@ type TLSFingerprintConfig struct { } // TLSProfileConfig 单个TLS指纹模板的配置 +// 所有列表字段为空时使用内置默认值(Claude CLI 2.x / Node.js 20.x) +// 建议通过 TLS 指纹采集工具 (tests/tls-fingerprint-web) 获取完整配置 type TLSProfileConfig struct { // Name: 模板显示名称 Name string `mapstructure:"name"` // EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用) EnableGREASE bool `mapstructure:"enable_grease"` - // CipherSuites: TLS加密套件列表(空则使用内置默认值) + // CipherSuites: TLS加密套件列表 CipherSuites []uint16 `mapstructure:"cipher_suites"` - // Curves: 椭圆曲线列表(空则使用内置默认值) + // Curves: 椭圆曲线列表 Curves []uint16 `mapstructure:"curves"` - // PointFormats: 点格式列表(空则使用内置默认值) - PointFormats []uint8 `mapstructure:"point_formats"` + // PointFormats: 点格式列表 + PointFormats []uint16 `mapstructure:"point_formats"` + // SignatureAlgorithms: 签名算法列表 + SignatureAlgorithms []uint16 `mapstructure:"signature_algorithms"` + // ALPNProtocols: ALPN协议列表(如 ["h2", "http/1.1"]) + ALPNProtocols []string `mapstructure:"alpn_protocols"` + // SupportedVersions: 支持的TLS版本列表(如 [0x0304, 0x0303] 即 TLS1.3, TLS1.2) + SupportedVersions []uint16 `mapstructure:"supported_versions"` + // KeyShareGroups: Key Share中发送的曲线组(如 [29] 即 X25519) + KeyShareGroups []uint16 `mapstructure:"key_share_groups"` + // PSKModes: PSK密钥交换模式(如 [1] 即 psk_dhe_ke) + PSKModes []uint16 `mapstructure:"psk_modes"` + // 
Extensions: TLS扩展类型ID列表,按发送顺序排列 + // 空则使用内置默认顺序 [0,11,10,35,16,22,23,13,43,45,51] + // GREASE值(如0x0a0a)会自动插入GREASE扩展 + Extensions []uint16 `mapstructure:"extensions"` } // GatewaySchedulingConfig accounts scheduling configuration. @@ -1265,8 +1281,8 @@ func setDefaults() { viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index c51046a2..4e69ca02 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型 "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", // Claude Haiku → Sonnet(无 Haiku 支持) - "claude-haiku-4-5": "claude-sonnet-4-5", - "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-6", + "claude-haiku-4-5-20251001": "claude-sonnet-4-6", // Gemini 2.5 白名单 "gemini-2.5-flash": "gemini-2.5-flash", "gemini-2.5-flash-image": 
"gemini-2.5-flash-image", diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index fbac73d3..12139b51 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -267,6 +267,9 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) } } + // 收集需要异步设置隐私的 Antigravity OAuth 账号 + var privacyAccounts []*service.Account + for i := range dataPayload.Accounts { item := dataPayload.Accounts[i] if err := validateDataAccount(item); err != nil { @@ -314,7 +317,8 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) SkipDefaultGroupBind: skipDefaultGroupBind, } - if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil { + created, err := h.adminService.CreateAccount(ctx, accountInput) + if err != nil { result.AccountFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "account", @@ -323,9 +327,30 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) }) continue } + // 收集 Antigravity OAuth 账号,稍后异步设置隐私 + if created.Platform == service.PlatformAntigravity && created.Type == service.AccountTypeOAuth { + privacyAccounts = append(privacyAccounts, created) + } result.AccountCreated++ } + // 异步设置 Antigravity 隐私,避免大量导入时阻塞请求 + if len(privacyAccounts) > 0 { + adminSvc := h.adminService + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("import_antigravity_privacy_panic", "recover", r) + } + }() + bgCtx := context.Background() + for _, acc := range privacyAccounts { + adminSvc.ForceAntigravityPrivacy(bgCtx, acc) + } + slog.Info("import_antigravity_privacy_done", "count", len(privacyAccounts)) + }() + } + return result, nil } @@ -352,7 +377,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, 
page, pageSize, platform, accountType, status, search, 0) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "") if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 3ef213e1..ce5cffe4 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "log" + "log/slog" "net/http" "strconv" "strings" @@ -165,6 +166,8 @@ type AccountWithConcurrency struct { CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数 } +const accountListGroupUngroupedQueryValue = "ungrouped" + func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { item := AccountWithConcurrency{ Account: dto.AccountFromService(account), @@ -217,6 +220,7 @@ func (h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + privacyMode := strings.TrimSpace(c.Query("privacy_mode")) // 标准化和验证 search 参数 search = strings.TrimSpace(search) if len(search) > 100 { @@ -226,10 +230,23 @@ func (h *AccountHandler) List(c *gin.Context) { var groupID int64 if groupIDStr := c.Query("group"); groupIDStr != "" { - groupID, _ = strconv.ParseInt(groupIDStr, 10, 64) + if groupIDStr == accountListGroupUngroupedQueryValue { + groupID = service.AccountListGroupUngrouped + } else { + parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64) + if parseErr != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter")) + return + } + if parsedGroupID < 0 { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter")) + return + } + groupID = parsedGroupID + } } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, 
platform, accountType, status, search, groupID) + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) if err != nil { response.ErrorFrom(c, err) return @@ -520,6 +537,10 @@ func (h *AccountHandler) Create(c *gin.Context) { if execErr != nil { return nil, execErr } + // Antigravity OAuth: 新账号直接设置隐私 + h.adminService.ForceAntigravityPrivacy(ctx, account) + // OpenAI OAuth: 新账号直接设置隐私 + h.adminService.ForceOpenAIPrivacy(ctx, account) return h.buildAccountResponseWithRuntime(ctx, account), nil }) if err != nil { @@ -766,6 +787,8 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv if account.IsOpenAI() { tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(ctx, account) if err != nil { + // 刷新失败但 access_token 可能仍有效,尝试设置隐私 + h.adminService.EnsureOpenAIPrivacy(ctx, account) return nil, "", err } @@ -867,6 +890,8 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv // OpenAI OAuth: 刷新成功后检查并设置 privacy_mode h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount) + // Antigravity OAuth: 刷新成功后检查并设置 privacy_mode + h.adminService.EnsureAntigravityPrivacy(ctx, updatedAccount) return updatedAccount, "", nil } @@ -1138,6 +1163,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { success := 0 failed := 0 results := make([]gin.H, 0, len(req.Accounts)) + // 收集需要异步设置隐私的 OAuth 账号 + var antigravityPrivacyAccounts []*service.Account + var openaiPrivacyAccounts []*service.Account for _, item := range req.Accounts { if item.RateMultiplier != nil && *item.RateMultiplier < 0 { @@ -1180,6 +1208,15 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { }) continue } + // 收集需要异步设置隐私的 OAuth 账号 + if account.Type == service.AccountTypeOAuth { + switch account.Platform { + case service.PlatformAntigravity: + antigravityPrivacyAccounts = append(antigravityPrivacyAccounts, account) + case service.PlatformOpenAI: + 
openaiPrivacyAccounts = append(openaiPrivacyAccounts, account) + } + } success++ results = append(results, gin.H{ "name": item.Name, @@ -1188,6 +1225,37 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { }) } + // 异步设置隐私,避免批量创建时阻塞请求 + adminSvc := h.adminService + if len(antigravityPrivacyAccounts) > 0 { + accounts := antigravityPrivacyAccounts + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("batch_create_antigravity_privacy_panic", "recover", r) + } + }() + bgCtx := context.Background() + for _, acc := range accounts { + adminSvc.ForceAntigravityPrivacy(bgCtx, acc) + } + }() + } + if len(openaiPrivacyAccounts) > 0 { + accounts := openaiPrivacyAccounts + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("batch_create_openai_privacy_panic", "recover", r) + } + }() + bgCtx := context.Background() + for _, acc := range accounts { + adminSvc.ForceOpenAIPrivacy(bgCtx, acc) + } + }() + } + return gin.H{ "success": success, "failed": failed, @@ -1496,7 +1564,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) { } // GetUsage handles getting account usage information -// GET /api/v1/admin/accounts/:id/usage +// GET /api/v1/admin/accounts/:id/usage?source=passive|active func (h *AccountHandler) GetUsage(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -1504,7 +1572,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) { return } - usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID) + source := c.DefaultQuery("source", "active") + + var usage *service.UsageInfo + if source == "passive" { + usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID) + } else { + usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID) + } if err != nil { response.ErrorFrom(c, err) return @@ -1846,6 +1921,51 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { response.Success(c, models) } +// 
SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account +// POST /api/v1/admin/accounts/:id/set-privacy +func (h *AccountHandler) SetPrivacy(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + if account.Type != service.AccountTypeOAuth { + response.BadRequest(c, "Only OAuth accounts support privacy setting") + return + } + var mode string + switch account.Platform { + case service.PlatformOpenAI: + mode = h.adminService.ForceOpenAIPrivacy(c.Request.Context(), account) + case service.PlatformAntigravity: + mode = h.adminService.ForceAntigravityPrivacy(c.Request.Context(), account) + default: + response.BadRequest(c, "Only OpenAI and Antigravity OAuth accounts support privacy setting") + return + } + if mode == "" { + response.BadRequest(c, "Cannot set privacy: missing access_token") + return + } + // 从 DB 重新读取以确保返回最新状态 + updated, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + // 隐私已设置成功但读取失败,回退到内存更新 + if account.Extra == nil { + account.Extra = make(map[string]any) + } + account.Extra["privacy_mode"] = mode + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) + return + } + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updated)) +} + // RefreshTier handles refreshing Google One tier for a single account // POST /api/v1/admin/accounts/:id/refresh-tier func (h *AccountHandler) RefreshTier(c *gin.Context) { @@ -1914,7 +2034,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) + allAccounts, _, err := 
h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 4de10d3e..cba3ae21 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { adminSvc := newStubAdminService() userHandler := NewUserHandler(adminSvc, nil) - groupHandler := NewGroupHandler(adminSvc) + groupHandler := NewGroupHandler(adminSvc, nil, nil) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc, nil) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 37a72cb4..9759cef5 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } @@ -445,5 +445,21 @@ func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *ser return "" } +func (s *stubAdminService) EnsureAntigravityPrivacy(ctx context.Context, account *service.Account) string { + return "" +} + +func (s *stubAdminService) ForceOpenAIPrivacy(ctx context.Context, account *service.Account) string { + return "" +} + +func (s *stubAdminService) 
ForceAntigravityPrivacy(ctx context.Context, account *service.Account) string { + return "" +} + +func (s *stubAdminService) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*service.ReplaceUserGroupResult, error) { + return &service.ReplaceUserGroupResult{MigratedKeys: 0}, nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/backup_handler.go b/backend/internal/handler/admin/backup_handler.go index d19713ee..2f528322 100644 --- a/backend/internal/handler/admin/backup_handler.go +++ b/backend/internal/handler/admin/backup_handler.go @@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) { expireDays = *req.ExpireDays } - record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays) + record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, record) + response.Accepted(c, record) } func (h *BackupHandler) ListBackups(c *gin.Context) { @@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) { return } - if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil { + record, err := h.backupService.StartRestore(c.Request.Context(), backupID) + if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, gin.H{"restored": true}) + response.Accepted(c, record) } diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index f415b48f..2a214471 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" 
"github.com/gin-gonic/gin" @@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + modelSource := usagestats.ModelSourceRequested var requestType *int16 var stream *bool var billingType *int8 @@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } + if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" { + if !usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + modelSource = rawModelSource + } if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { parsed, err := service.ParseUsageRequestType(requestTypeStr) if err != nil { @@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return @@ -604,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { c.Header("X-Snapshot-Cache", "miss") response.Success(c, payload) } + +// GetUserBreakdown handles getting per-user usage breakdown within a dimension. 
+// GET /api/v1/admin/dashboard/user-breakdown +// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit +func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + + dim := usagestats.UserBreakdownDimension{} + if v := c.Query("group_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.GroupID = id + } + } + dim.Model = c.Query("model") + rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested)) + if !usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + dim.ModelType = rawModelSource + dim.Endpoint = c.Query("endpoint") + dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") + + limit := 50 + if v := c.Query("limit"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { + limit = n + } + } + + stats, err := h.dashboardService.GetUserBreakdownStats( + c.Request.Context(), startTime, endTime, dim, limit, + ) + if err != nil { + response.Error(c, 500, "Failed to get user breakdown stats") + return + } + + response.Success(c, gin.H{ + "users": stats, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + }) +} diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 9aec61d4..6056f725 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestDashboardModelStatsInvalidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := 
httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsValidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestDashboardUsersRankingLimitAndCache(t *testing.T) { dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) repo := &dashboardUsageRepoCapture{ diff --git a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go new file mode 100644 index 00000000..b3a05111 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go @@ -0,0 +1,229 @@ +package admin + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- mock repo --- + +type userBreakdownRepoCapture struct { + service.UsageLogRepository + capturedDim usagestats.UserBreakdownDimension + capturedLimit int + result []usagestats.UserBreakdownItem +} + +func (r *userBreakdownRepoCapture) GetUserBreakdownStats( + _ context.Context, _, _ time.Time, + dim usagestats.UserBreakdownDimension, limit int, +) ([]usagestats.UserBreakdownItem, error) { + r.capturedDim = dim + r.capturedLimit = limit + if r.result != nil { + return r.result, nil + } + return []usagestats.UserBreakdownItem{}, nil +} + +func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + 
svc := service.NewDashboardService(repo, nil, nil, nil) + h := NewDashboardHandler(svc, nil) + router := gin.New() + router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown) + return router +} + +// --- tests --- + +func TestGetUserBreakdown_GroupIDFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, int64(42), repo.capturedDim.GroupID) + require.Empty(t, repo.capturedDim.Model) + require.Empty(t, repo.capturedDim.Endpoint) + require.Equal(t, 50, repo.capturedLimit) // default limit +} + +func TestGetUserBreakdown_ModelFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) + require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType) + require.Equal(t, int64(0), repo.capturedDim.GroupID) +} + +func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType) +} + +func TestGetUserBreakdown_InvalidModelSource(t *testing.T) { + repo := 
&userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestGetUserBreakdown_EndpointFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint) + require.Equal(t, "upstream", repo.capturedDim.EndpointType) +} + +func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "inbound", repo.capturedDim.EndpointType) +} + +func TestGetUserBreakdown_CustomLimit(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 100, repo.capturedLimit) +} + +func TestGetUserBreakdown_LimitClamped(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + // limit > 200 should fall back to default 50 + req := httptest.NewRequest(http.MethodGet, + 
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 50, repo.capturedLimit) +} + +func TestGetUserBreakdown_ResponseFormat(t *testing.T) { + repo := &userBreakdownRepoCapture{ + result: []usagestats.UserBreakdownItem{ + {UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2}, + {UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6}, + }, + } + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Users []usagestats.UserBreakdownItem `json:"users"` + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + } `json:"data"` + } + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Users, 2) + require.Equal(t, int64(1), resp.Data.Users[0].UserID) + require.Equal(t, "alice@test.com", resp.Data.Users[0].Email) + require.Equal(t, int64(100), resp.Data.Users[0].Requests) + require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001) + require.Equal(t, "2026-03-01", resp.Data.StartDate) + require.Equal(t, "2026-03-16", resp.Data.EndDate) +} + +func TestGetUserBreakdown_EmptyResult(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var resp struct { + 
Data struct { + Users []usagestats.UserBreakdownItem `json:"users"` + } `json:"data"` + } + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + require.Empty(t, resp.Data.Users) +} + +func TestGetUserBreakdown_NoFilters(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, int64(0), repo.capturedDim.GroupID) + require.Empty(t, repo.capturedDim.Model) + require.Empty(t, repo.capturedDim.Endpoint) +} diff --git a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go index 47af5117..815c5161 100644 --- a/backend/internal/handler/admin/dashboard_query_cache.go +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct { APIKeyID int64 `json:"api_key_id"` AccountID int64 `json:"account_id"` GroupID int64 `json:"group_id"` + ModelSource string `json:"model_source,omitempty"` RequestType *int16 `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` @@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached( ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, + modelSource string, requestType *int16, stream *bool, billingType *int8, @@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached( APIKeyID: apiKeyID, AccountID: accountID, GroupID: groupID, + ModelSource: usagestats.NormalizeModelSource(modelSource), RequestType: requestType, Stream: stream, BillingType: billingType, }) entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { - return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, 
apiKeyID, accountID, groupID, requestType, stream, billingType) + return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource) }) if err != nil { return nil, hit, err diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go index 16e10339..517ae7bd 100644 --- a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response( filters.APIKeyID, filters.AccountID, filters.GroupID, + usagestats.ModelSourceRequested, filters.RequestType, filters.Stream, filters.BillingType, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 4ffe64ee..caa27bc3 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -16,7 +17,9 @@ import ( // GroupHandler handles admin group management type GroupHandler struct { - adminService service.AdminService + adminService service.AdminService + dashboardService *service.DashboardService + groupCapacityService *service.GroupCapacityService } type optionalLimitField struct { @@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 { } // NewGroupHandler creates a new admin group handler -func NewGroupHandler(adminService service.AdminService) *GroupHandler { +func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler { 
return &GroupHandler{ - adminService: adminService, + adminService: adminService, + dashboardService: dashboardService, + groupCapacityService: groupCapacityService, } } @@ -107,6 +112,8 @@ type CreateGroupRequest struct { SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -145,6 +152,8 @@ type UpdateGroupRequest struct { SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -262,6 +271,8 @@ func (h *GroupHandler) Create(c *gin.Context) { SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -315,6 +326,8 @@ func (h *GroupHandler) Update(c *gin.Context) { SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -363,6 +376,33 @@ func (h *GroupHandler) GetStats(c 
*gin.Context) { _ = groupID // TODO: implement actual stats } +// GetUsageSummary returns today's and cumulative cost for all groups. +// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai +func (h *GroupHandler) GetUsageSummary(c *gin.Context) { + userTZ := c.Query("timezone") + now := timezone.NowInUserLocation(userTZ) + todayStart := timezone.StartOfDayInUserLocation(now, userTZ) + + results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart) + if err != nil { + response.Error(c, 500, "Failed to get group usage summary") + return + } + + response.Success(c, results) +} + +// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups. +// GET /api/v1/admin/groups/capacity-summary +func (h *GroupHandler) GetCapacitySummary(c *gin.Context) { + results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context()) + if err != nil { + response.Error(c, 500, "Failed to get group capacity summary") + return + } + response.Success(c, results) +} + // GetGroupAPIKeys handles getting API keys in a group // GET /api/v1/admin/groups/:id/api-keys func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index c966cb7d..397526a7 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -125,8 +126,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { 
OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: settings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion, AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, BackendModeEnabled: settings.BackendModeEnabled, + EnableFingerprintUnification: settings.EnableFingerprintUnification, + EnableMetadataPassthrough: settings.EnableMetadataPassthrough, }) } @@ -175,6 +179,7 @@ type UpdateSettingsRequest struct { PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -199,12 +204,17 @@ type UpdateSettingsRequest struct { OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` MinClaudeCodeVersion string `json:"min_claude_code_version"` + MaxClaudeCodeVersion string `json:"max_claude_code_version"` // 分组隔离 AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` // Backend Mode BackendModeEnabled bool `json:"backend_mode_enabled"` + + // Gateway forwarding behavior + EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` } // UpdateSettings 更新系统设置 @@ -229,11 +239,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.DefaultBalance < 0 { req.DefaultBalance = 0 } + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + req.SMTPPassword = strings.TrimSpace(req.SMTPPassword) + req.SMTPFrom = strings.TrimSpace(req.SMTPFrom) + req.SMTPFromName = strings.TrimSpace(req.SMTPFromName) if req.SMTPPort <= 0 { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + // SMTP 
配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 + // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 + if req.SMTPHost == "" && previousSettings.SMTPHost != "" { + req.SMTPHost = previousSettings.SMTPHost + req.SMTPPort = previousSettings.SMTPPort + req.SMTPUsername = previousSettings.SMTPUsername + req.SMTPFrom = previousSettings.SMTPFrom + req.SMTPFromName = previousSettings.SMTPFromName + req.SMTPUseTLS = previousSettings.SMTPUseTLS + } + // Turnstile 参数验证 if req.TurnstileEnabled { // 检查必填字段 @@ -415,6 +441,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { customMenuJSON = string(menuBytes) } + // 自定义端点验证 + const ( + maxCustomEndpoints = 10 + maxEndpointNameLen = 50 + maxEndpointURLLen = 2048 + maxEndpointDescriptionLen = 200 + ) + + customEndpointsJSON := previousSettings.CustomEndpoints + if req.CustomEndpoints != nil { + endpoints := *req.CustomEndpoints + if len(endpoints) > maxCustomEndpoints { + response.BadRequest(c, "Too many custom endpoints (max 10)") + return + } + for _, ep := range endpoints { + if strings.TrimSpace(ep.Name) == "" { + response.BadRequest(c, "Custom endpoint name is required") + return + } + if len(ep.Name) > maxEndpointNameLen { + response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)") + return + } + if strings.TrimSpace(ep.Endpoint) == "" { + response.BadRequest(c, "Custom endpoint URL is required") + return + } + if len(ep.Endpoint) > maxEndpointURLLen { + response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil { + response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL") + return + } + if len(ep.Description) > maxEndpointDescriptionLen { + response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)") + return + } + } + endpointBytes, err := json.Marshal(endpoints) + if err != nil { + response.BadRequest(c, "Failed to serialize custom endpoints") + 
return + } + customEndpointsJSON = string(endpointBytes) + } + // Ops metrics collector interval validation (seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -442,6 +517,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // 验证最高版本号格式(空字符串=禁用,或合法 semver) + if req.MaxClaudeCodeVersion != "" { + if !semverPattern.MatchString(req.MaxClaudeCodeVersion) { + response.Error(c, http.StatusBadRequest, "max_claude_code_version must be empty or a valid semver (e.g. 3.0.0)") + return + } + } + + // 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号 + if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" { + if service.CompareVersions(req.MaxClaudeCodeVersion, req.MinClaudeCodeVersion) < 0 { + response.Error(c, http.StatusBadRequest, "max_claude_code_version must be greater than or equal to min_claude_code_version") + return + } + } + settings := &service.SystemSettings{ RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, @@ -477,6 +568,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: purchaseURL, SoraClientEnabled: req.SoraClientEnabled, CustomMenuItems: customMenuJSON, + CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -488,6 +580,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableIdentityPatch: req.EnableIdentityPatch, IdentityPatchPrompt: req.IdentityPatchPrompt, MinClaudeCodeVersion: req.MinClaudeCodeVersion, + MaxClaudeCodeVersion: req.MaxClaudeCodeVersion, AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, BackendModeEnabled: req.BackendModeEnabled, OpsMonitoringEnabled: func() bool { @@ -514,6 +607,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.OpsMetricsIntervalSeconds }(), + EnableFingerprintUnification: func() bool { + if 
req.EnableFingerprintUnification != nil { + return *req.EnableFingerprintUnification + } + return previousSettings.EnableFingerprintUnification + }(), + EnableMetadataPassthrough: func() bool { + if req.EnableMetadataPassthrough != nil { + return *req.EnableMetadataPassthrough + } + return previousSettings.EnableMetadataPassthrough + }(), } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -573,6 +678,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, @@ -588,8 +694,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion, AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, BackendModeEnabled: updatedSettings.BackendModeEnabled, + EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, + EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, }) } @@ -744,6 +853,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { changed = append(changed, "min_claude_code_version") } + if before.MaxClaudeCodeVersion != after.MaxClaudeCodeVersion { + changed = append(changed, "max_claude_code_version") + } if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { changed = append(changed, 
"allow_ungrouped_key_scheduling") } @@ -759,6 +871,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.CustomMenuItems != after.CustomMenuItems { changed = append(changed, "custom_menu_items") } + if before.EnableFingerprintUnification != after.EnableFingerprintUnification { + changed = append(changed, "enable_fingerprint_unification") + } + if before.EnableMetadataPassthrough != after.EnableMetadataPassthrough { + changed = append(changed, "enable_metadata_passthrough") + } return changed } @@ -805,7 +923,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { - SMTPHost string `json:"smtp_host" binding:"required"` + SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` SMTPUsername string `json:"smtp_username"` SMTPPassword string `json:"smtp_password"` @@ -821,18 +939,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { return } - if req.SMTPPort <= 0 { - req.SMTPPort = 587 + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + + var savedConfig *service.SMTPConfig + if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil { + savedConfig = cfg } - // 如果未提供密码,从数据库获取已保存的密码 - password := req.SMTPPassword - if password == "" { - savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) - if err == nil && savedConfig != nil { - password = savedConfig.Password + if req.SMTPHost == "" && savedConfig != nil { + req.SMTPHost = savedConfig.Host + } + if req.SMTPPort <= 0 { + if savedConfig != nil && savedConfig.Port > 0 { + req.SMTPPort = savedConfig.Port + } else { + req.SMTPPort = 587 } } + if req.SMTPUsername == "" && savedConfig != nil { + req.SMTPUsername = savedConfig.Username + } + password := strings.TrimSpace(req.SMTPPassword) + if password == "" && savedConfig != nil { + password = savedConfig.Password + 
} + if req.SMTPHost == "" { + response.BadRequest(c, "SMTP host is required") + return + } config := &service.SMTPConfig{ Host: req.SMTPHost, @@ -854,7 +989,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { // SendTestEmailRequest 发送测试邮件请求 type SendTestEmailRequest struct { Email string `json:"email" binding:"required,email"` - SMTPHost string `json:"smtp_host" binding:"required"` + SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` SMTPUsername string `json:"smtp_username"` SMTPPassword string `json:"smtp_password"` @@ -872,18 +1007,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { return } - if req.SMTPPort <= 0 { - req.SMTPPort = 587 + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + req.SMTPFrom = strings.TrimSpace(req.SMTPFrom) + req.SMTPFromName = strings.TrimSpace(req.SMTPFromName) + + var savedConfig *service.SMTPConfig + if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil { + savedConfig = cfg } - // 如果未提供密码,从数据库获取已保存的密码 - password := req.SMTPPassword - if password == "" { - savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) - if err == nil && savedConfig != nil { - password = savedConfig.Password + if req.SMTPHost == "" && savedConfig != nil { + req.SMTPHost = savedConfig.Host + } + if req.SMTPPort <= 0 { + if savedConfig != nil && savedConfig.Port > 0 { + req.SMTPPort = savedConfig.Port + } else { + req.SMTPPort = 587 } } + if req.SMTPUsername == "" && savedConfig != nil { + req.SMTPUsername = savedConfig.Username + } + password := strings.TrimSpace(req.SMTPPassword) + if password == "" && savedConfig != nil { + password = savedConfig.Password + } + if req.SMTPFrom == "" && savedConfig != nil { + req.SMTPFrom = savedConfig.From + } + if req.SMTPFromName == "" && savedConfig != nil { + req.SMTPFromName = savedConfig.FromName + } + if req.SMTPHost == "" { + response.BadRequest(c, "SMTP host is 
required") + return + } config := &service.SMTPConfig{ Host: req.SMTPHost, @@ -977,6 +1137,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) { response.Success(c, gin.H{"message": "Admin API key deleted"}) } +// GetOverloadCooldownSettings 获取529过载冷却配置 +// GET /api/v1/admin/settings/overload-cooldown +func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) { + settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.OverloadCooldownSettings{ + Enabled: settings.Enabled, + CooldownMinutes: settings.CooldownMinutes, + }) +} + +// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求 +type UpdateOverloadCooldownSettingsRequest struct { + Enabled bool `json:"enabled"` + CooldownMinutes int `json:"cooldown_minutes"` +} + +// UpdateOverloadCooldownSettings 更新529过载冷却配置 +// PUT /api/v1/admin/settings/overload-cooldown +func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) { + var req UpdateOverloadCooldownSettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + settings := &service.OverloadCooldownSettings{ + Enabled: req.Enabled, + CooldownMinutes: req.CooldownMinutes, + } + + if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.OverloadCooldownSettings{ + Enabled: updatedSettings.Enabled, + CooldownMinutes: updatedSettings.CooldownMinutes, + }) +} + // GetStreamTimeoutSettings 获取流超时处理配置 // GET /api/v1/admin/settings/stream-timeout func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { @@ -1382,18 +1594,26 @@ func (h *SettingHandler) 
GetRectifierSettings(c *gin.Context) { return } + patterns := settings.APIKeySignaturePatterns + if patterns == nil { + patterns = []string{} + } response.Success(c, dto.RectifierSettings{ Enabled: settings.Enabled, ThinkingSignatureEnabled: settings.ThinkingSignatureEnabled, ThinkingBudgetEnabled: settings.ThinkingBudgetEnabled, + APIKeySignatureEnabled: settings.APIKeySignatureEnabled, + APIKeySignaturePatterns: patterns, }) } // UpdateRectifierSettingsRequest 更新整流器配置请求 type UpdateRectifierSettingsRequest struct { - Enabled bool `json:"enabled"` - ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` - ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` + APIKeySignatureEnabled bool `json:"apikey_signature_enabled"` + APIKeySignaturePatterns []string `json:"apikey_signature_patterns"` } // UpdateRectifierSettings 更新请求整流器配置 @@ -1405,10 +1625,32 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) { return } + // 校验并清理自定义匹配关键词 + const maxPatterns = 50 + const maxPatternLen = 500 + if len(req.APIKeySignaturePatterns) > maxPatterns { + response.BadRequest(c, "Too many signature patterns (max 50)") + return + } + var cleanedPatterns []string + for _, p := range req.APIKeySignaturePatterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if len(p) > maxPatternLen { + response.BadRequest(c, "Signature pattern too long (max 500 characters)") + return + } + cleanedPatterns = append(cleanedPatterns, p) + } + settings := &service.RectifierSettings{ Enabled: req.Enabled, ThinkingSignatureEnabled: req.ThinkingSignatureEnabled, ThinkingBudgetEnabled: req.ThinkingBudgetEnabled, + APIKeySignatureEnabled: req.APIKeySignatureEnabled, + APIKeySignaturePatterns: cleanedPatterns, } if err := h.settingService.SetRectifierSettings(c.Request.Context(), settings); err != nil { @@ 
-1423,10 +1665,16 @@ func (h *SettingHandler) UpdateRectifierSettings(c *gin.Context) { return } + updatedPatterns := updatedSettings.APIKeySignaturePatterns + if updatedPatterns == nil { + updatedPatterns = []string{} + } response.Success(c, dto.RectifierSettings{ Enabled: updatedSettings.Enabled, ThinkingSignatureEnabled: updatedSettings.ThinkingSignatureEnabled, ThinkingBudgetEnabled: updatedSettings.ThinkingBudgetEnabled, + APIKeySignatureEnabled: updatedSettings.APIKeySignatureEnabled, + APIKeySignaturePatterns: updatedPatterns, }) } diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 342964b6..611666de 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) { } } status := c.Query("status") + platform := c.Query("platform") // Parse sorting parameters sortBy := c.DefaultQuery("sort_by", "created_at") sortOrder := c.DefaultQuery("sort_order", "desc") - subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) + subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/tls_fingerprint_profile_handler.go b/backend/internal/handler/admin/tls_fingerprint_profile_handler.go new file mode 100644 index 00000000..38f97555 --- /dev/null +++ b/backend/internal/handler/admin/tls_fingerprint_profile_handler.go @@ -0,0 +1,234 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// TLSFingerprintProfileHandler 处理 
TLS 指纹模板的 HTTP 请求 +type TLSFingerprintProfileHandler struct { + service *service.TLSFingerprintProfileService +} + +// NewTLSFingerprintProfileHandler 创建 TLS 指纹模板处理器 +func NewTLSFingerprintProfileHandler(service *service.TLSFingerprintProfileService) *TLSFingerprintProfileHandler { + return &TLSFingerprintProfileHandler{service: service} +} + +// CreateTLSFingerprintProfileRequest 创建模板请求 +type CreateTLSFingerprintProfileRequest struct { + Name string `json:"name" binding:"required"` + Description *string `json:"description"` + EnableGREASE *bool `json:"enable_grease"` + CipherSuites []uint16 `json:"cipher_suites"` + Curves []uint16 `json:"curves"` + PointFormats []uint16 `json:"point_formats"` + SignatureAlgorithms []uint16 `json:"signature_algorithms"` + ALPNProtocols []string `json:"alpn_protocols"` + SupportedVersions []uint16 `json:"supported_versions"` + KeyShareGroups []uint16 `json:"key_share_groups"` + PSKModes []uint16 `json:"psk_modes"` + Extensions []uint16 `json:"extensions"` +} + +// UpdateTLSFingerprintProfileRequest 更新模板请求(部分更新) +type UpdateTLSFingerprintProfileRequest struct { + Name *string `json:"name"` + Description *string `json:"description"` + EnableGREASE *bool `json:"enable_grease"` + CipherSuites []uint16 `json:"cipher_suites"` + Curves []uint16 `json:"curves"` + PointFormats []uint16 `json:"point_formats"` + SignatureAlgorithms []uint16 `json:"signature_algorithms"` + ALPNProtocols []string `json:"alpn_protocols"` + SupportedVersions []uint16 `json:"supported_versions"` + KeyShareGroups []uint16 `json:"key_share_groups"` + PSKModes []uint16 `json:"psk_modes"` + Extensions []uint16 `json:"extensions"` +} + +// List 获取所有模板 +// GET /api/v1/admin/tls-fingerprint-profiles +func (h *TLSFingerprintProfileHandler) List(c *gin.Context) { + profiles, err := h.service.List(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profiles) +} + +// GetByID 根据 ID 获取模板 +// GET 
/api/v1/admin/tls-fingerprint-profiles/:id +func (h *TLSFingerprintProfileHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid profile ID") + return + } + + profile, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if profile == nil { + response.NotFound(c, "Profile not found") + return + } + + response.Success(c, profile) +} + +// Create 创建模板 +// POST /api/v1/admin/tls-fingerprint-profiles +func (h *TLSFingerprintProfileHandler) Create(c *gin.Context) { + var req CreateTLSFingerprintProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + profile := &model.TLSFingerprintProfile{ + Name: req.Name, + Description: req.Description, + CipherSuites: req.CipherSuites, + Curves: req.Curves, + PointFormats: req.PointFormats, + SignatureAlgorithms: req.SignatureAlgorithms, + ALPNProtocols: req.ALPNProtocols, + SupportedVersions: req.SupportedVersions, + KeyShareGroups: req.KeyShareGroups, + PSKModes: req.PSKModes, + Extensions: req.Extensions, + } + + if req.EnableGREASE != nil { + profile.EnableGREASE = *req.EnableGREASE + } + + created, err := h.service.Create(c.Request.Context(), profile) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, created) +} + +// Update 更新模板(支持部分更新) +// PUT /api/v1/admin/tls-fingerprint-profiles/:id +func (h *TLSFingerprintProfileHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid profile ID") + return + } + + var req UpdateTLSFingerprintProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + existing, err := 
h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if existing == nil { + response.NotFound(c, "Profile not found") + return + } + + // 部分更新 + profile := &model.TLSFingerprintProfile{ + ID: id, + Name: existing.Name, + Description: existing.Description, + EnableGREASE: existing.EnableGREASE, + CipherSuites: existing.CipherSuites, + Curves: existing.Curves, + PointFormats: existing.PointFormats, + SignatureAlgorithms: existing.SignatureAlgorithms, + ALPNProtocols: existing.ALPNProtocols, + SupportedVersions: existing.SupportedVersions, + KeyShareGroups: existing.KeyShareGroups, + PSKModes: existing.PSKModes, + Extensions: existing.Extensions, + } + + if req.Name != nil { + profile.Name = *req.Name + } + if req.Description != nil { + profile.Description = req.Description + } + if req.EnableGREASE != nil { + profile.EnableGREASE = *req.EnableGREASE + } + if req.CipherSuites != nil { + profile.CipherSuites = req.CipherSuites + } + if req.Curves != nil { + profile.Curves = req.Curves + } + if req.PointFormats != nil { + profile.PointFormats = req.PointFormats + } + if req.SignatureAlgorithms != nil { + profile.SignatureAlgorithms = req.SignatureAlgorithms + } + if req.ALPNProtocols != nil { + profile.ALPNProtocols = req.ALPNProtocols + } + if req.SupportedVersions != nil { + profile.SupportedVersions = req.SupportedVersions + } + if req.KeyShareGroups != nil { + profile.KeyShareGroups = req.KeyShareGroups + } + if req.PSKModes != nil { + profile.PSKModes = req.PSKModes + } + if req.Extensions != nil { + profile.Extensions = req.Extensions + } + + updated, err := h.service.Update(c.Request.Context(), profile) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, updated) +} + +// Delete 删除模板 +// DELETE /api/v1/admin/tls-fingerprint-profiles/:id +func (h *TLSFingerprintProfileHandler) 
Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid profile ID") + return + } + + if err := h.service.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Profile deleted successfully"}) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 5a55ab14..998308dd 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -75,6 +75,7 @@ type UpdateBalanceRequest struct { // - role: filter by user role // - search: search in email, username // - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company +// - group_name: fuzzy filter by allowed group name func (h *UserHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) @@ -89,6 +90,7 @@ func (h *UserHandler) List(c *gin.Context) { Status: c.Query("status"), Role: c.Query("role"), Search: search, + GroupName: strings.TrimSpace(c.Query("group_name")), Attributes: parseAttributeFilters(c), } if raw, ok := c.GetQuery("include_subscriptions"); ok { @@ -366,3 +368,35 @@ func (h *UserHandler) GetBalanceHistory(c *gin.Context) { "total_recharged": totalRecharged, }) } + +// ReplaceGroupRequest represents the request to replace a user's exclusive group +type ReplaceGroupRequest struct { + OldGroupID int64 `json:"old_group_id" binding:"required,gt=0"` + NewGroupID int64 `json:"new_group_id" binding:"required,gt=0"` +} + +// ReplaceGroup handles replacing a user's exclusive group +// POST /api/v1/admin/users/:id/replace-group +func (h *UserHandler) ReplaceGroup(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req ReplaceGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid 
request: "+err.Error()) + return + } + + result, err := h.adminService.ReplaceUserGroup(c.Request.Context(), userID, req.OldGroupID, req.NewGroupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "migrated_keys": result.MigratedKeys, + }) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 8e5f23e7..a8da92c0 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -179,6 +181,8 @@ func groupFromServiceBase(g *service.Group) Group { FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOAuthOnly, + RequirePrivacySet: g.RequirePrivacySet, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -250,6 +254,10 @@ func AccountFromServiceShallow(a *service.Account) *Account { enabled := true out.EnableTLSFingerprint = &enabled } + // TLS指纹模板ID + if profileID := a.GetTLSFingerprintProfileID(); profileID > 0 { + 
out.TLSFingerprintProfileID = &profileID + } // 会话ID伪装开关 if a.IsSessionIDMaskingEnabled() { enabled := true @@ -262,6 +270,14 @@ func AccountFromServiceShallow(a *service.Account) *Account { target := a.GetCacheTTLOverrideTarget() out.CacheTTLOverrideTarget = &target } + // 自定义 Base URL 中继转发 + if a.IsCustomBaseURLEnabled() { + enabled := true + out.CustomBaseURLEnabled = &enabled + if customURL := a.GetCustomBaseURL(); customURL != "" { + out.CustomBaseURL = &customURL + } + } } // 提取账号配额限制(apikey / bedrock 类型有效) @@ -274,11 +290,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { if limit := a.GetQuotaDailyLimit(); limit > 0 { out.QuotaDailyLimit = &limit used := a.GetQuotaDailyUsed() + if a.IsDailyQuotaPeriodExpired() { + used = 0 + } out.QuotaDailyUsed = &used } if limit := a.GetQuotaWeeklyLimit(); limit > 0 { out.QuotaWeeklyLimit = &limit used := a.GetQuotaWeeklyUsed() + if a.IsWeeklyQuotaPeriodExpired() { + used = 0 + } out.QuotaWeeklyUsed = &used } // 固定时间重置配置 @@ -514,13 +536,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 requestType := l.EffectiveRequestType() stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) + requestedModel := l.RequestedModel + if requestedModel == "" { + requestedModel = l.Model + } return UsageLog{ ID: l.ID, UserID: l.UserID, APIKeyID: l.APIKeyID, AccountID: l.AccountID, RequestID: l.RequestID, - Model: l.Model, + Model: requestedModel, ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, InboundEndpoint: l.InboundEndpoint, @@ -577,6 +603,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { } return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), + UpstreamModel: l.UpstreamModel, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git 
a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index e4031970..c2635e33 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -1,6 +1,7 @@ package dto import ( + "encoding/json" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) } +func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) { + t.Parallel() + + upstreamModel := "claude-sonnet-4-20250514" + log := &service.UsageLog{ + RequestID: "req_4", + Model: upstreamModel, + RequestedModel: "claude-sonnet-4", + UpstreamModel: &upstreamModel, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "claude-sonnet-4", userDTO.Model) + require.Equal(t, "claude-sonnet-4", adminDTO.Model) + + userJSON, err := json.Marshal(userDTO) + require.NoError(t, err) + require.NotContains(t, string(userJSON), "upstream_model") + + adminJSON, err := json.Marshal(adminDTO) + require.NoError(t, err) + require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`) +} + +func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_legacy", + Model: "claude-3", + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "claude-3", userDTO.Model) + require.Equal(t, "claude-3", adminDTO.Model) +} + func f64Ptr(value float64) *float64 { return &value } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 29b00bb8..47bab091 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -15,6 +15,13 @@ type 
CustomMenuItem struct { SortOrder int `json:"sort_order"` } +// CustomEndpoint represents an admin-configured API endpoint for quick copy. +type CustomEndpoint struct { + Name string `json:"name"` + Endpoint string `json:"endpoint"` + Description string `json:"description"` +} + // SystemSettings represents the admin settings API response payload. type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` @@ -56,6 +63,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -79,12 +87,17 @@ type SystemSettings struct { OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` MinClaudeCodeVersion string `json:"min_claude_code_version"` + MaxClaudeCodeVersion string `json:"max_claude_code_version"` // 分组隔离 AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` // Backend Mode BackendModeEnabled bool `json:"backend_mode_enabled"` + + // Gateway forwarding behavior + EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` } type DefaultSubscriptionSetting struct { @@ -113,6 +126,7 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` @@ -157,6 +171,12 @@ type ListSoraS3ProfilesResponse struct { Items []SoraS3Profile 
`json:"items"` } +// OverloadCooldownSettings 529过载冷却配置 DTO +type OverloadCooldownSettings struct { + Enabled bool `json:"enabled"` + CooldownMinutes int `json:"cooldown_minutes"` +} + // StreamTimeoutSettings 流超时处理配置 DTO type StreamTimeoutSettings struct { Enabled bool `json:"enabled"` @@ -168,9 +188,11 @@ type StreamTimeoutSettings struct { // RectifierSettings 请求整流器配置 DTO type RectifierSettings struct { - Enabled bool `json:"enabled"` - ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` - ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` + Enabled bool `json:"enabled"` + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` + APIKeySignatureEnabled bool `json:"apikey_signature_enabled"` + APIKeySignaturePatterns []string `json:"apikey_signature_patterns"` } // BetaPolicyRule Beta 策略规则 DTO @@ -211,3 +233,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { } return filtered } + +// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint. +// Returns empty slice on empty/invalid input. 
+func ParseCustomEndpoints(raw string) []CustomEndpoint { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomEndpoint{} + } + var items []CustomEndpoint + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomEndpoint{} + } + return items +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index c52e357e..46984044 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -102,6 +102,10 @@ type Group struct { // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + // 账号过滤控制(仅 OpenAI/Antigravity 平台有效) + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -122,9 +126,11 @@ type AdminGroup struct { DefaultMappedModel string `json:"default_mapped_model"` // 支持的模型系列(仅 antigravity 平台使用) - SupportedModelScopes []string `json:"supported_model_scopes"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 `json:"account_count,omitempty"` + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` + ActiveAccountCount int64 `json:"active_account_count,omitempty"` + RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"` // 分组排序 SortOrder int `json:"sort_order"` @@ -183,7 +189,8 @@ type Account struct { // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 - EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` + EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` + TLSFingerprintProfileID *int64 `json:"tls_fingerprint_profile_id,omitempty"` // 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 启用后将在15分钟内固定 metadata.user_id 中的 session 
ID @@ -195,6 +202,10 @@ type Account struct { CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + // 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效) + CustomBaseURLEnabled *bool `json:"custom_base_url_enabled,omitempty"` + CustomBaseURL *string `json:"custom_base_url,omitempty"` + // API Key 账号配额限制 QuotaLimit *float64 `json:"quota_limit,omitempty"` QuotaUsed *float64 `json:"quota_used,omitempty"` @@ -391,6 +402,10 @@ type UsageLog struct { type AdminUsageLog struct { UsageLog + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Omitted when no mapping was applied (requested model was used as-is). + UpstreamModel *string `json:"upstream_model,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 831029c4..a0d8b2e9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 验证 model 必填 if reqModel == "" { @@ -421,11 +422,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Error("gateway.forward_failed", + forwardFailedFields := []zap.Field{ zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("account_platform", account.Platform), zap.Bool("fallback_error_response_written", wroteFallback), zap.Error(err), - ) + } + if account.Proxy != 
nil { + forwardFailedFields = append(forwardFailedFields, + zap.Int64("proxy_id", account.Proxy.ID), + zap.String("proxy_name", account.Proxy.Name), + zap.String("proxy_host", account.Proxy.Host), + zap.Int("proxy_port", account.Proxy.Port), + ) + } else if account.ProxyID != nil { + forwardFailedFields = append(forwardFailedFields, zap.Int64p("proxy_id", account.ProxyID)) + } + reqLog.Error("gateway.forward_failed", forwardFailedFields...) return } @@ -740,11 +754,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Error("gateway.forward_failed", + forwardFailedFields := []zap.Field{ zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("account_platform", account.Platform), zap.Bool("fallback_error_response_written", wroteFallback), zap.Error(err), - ) + } + if account.Proxy != nil { + forwardFailedFields = append(forwardFailedFields, + zap.Int64("proxy_id", account.Proxy.ID), + zap.String("proxy_name", account.Proxy.Name), + zap.String("proxy_host", account.Proxy.Host), + zap.Int("proxy_port", account.Proxy.Port), + ) + } else if account.ProxyID != nil { + forwardFailedFields = append(forwardFailedFields, zap.Int64p("proxy_id", account.ProxyID)) + } + reqLog.Error("gateway.forward_failed", forwardFailedFields...) 
return } @@ -1219,6 +1246,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se } } + // 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误 + upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) + service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") + // 使用默认的错误映射 status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) @@ -1227,6 +1258,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se // handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) + service.SetOpsUpstreamError(c, statusCode, errMsg, "") h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } @@ -1276,7 +1308,7 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte return true } -// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求 +// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足版本要求 // 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外 func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { ctx := c.Request.Context() @@ -1289,8 +1321,8 @@ func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { return true } - minVersion := h.settingService.GetMinClaudeCodeVersion(ctx) - if minVersion == "" { + minVersion, maxVersion := h.settingService.GetClaudeCodeVersionBounds(ctx) + if minVersion == "" && maxVersion == "" { return true // 未设置,不检查 } @@ -1301,13 +1333,22 @@ func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { return false } - if service.CompareVersions(clientVersion, minVersion) < 0 { + if minVersion != "" && service.CompareVersions(clientVersion, minVersion) < 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). 
Please update: npm update -g @anthropic-ai/claude-code", clientVersion, minVersion)) return false } + if maxVersion != "" && service.CompareVersions(clientVersion, maxVersion) > 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Your Claude Code version (%s) exceeds the maximum allowed version (%s). "+ + "Please downgrade: npm install -g @anthropic-ai/claude-code@%s && "+ + "set CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 to prevent auto-upgrade", + clientVersion, maxVersion, maxVersion)) + return false + } + return true } @@ -1382,6 +1423,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false))) // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go new file mode 100644 index 00000000..da376036 --- /dev/null +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -0,0 +1,289 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups. +// POST /v1/chat/completions +// This converts Chat Completions requests to Anthropic format (via Responses format chain), +// forwards to Anthropic upstream, and converts responses back to Chat Completions format. 
+func (h *GatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, 
reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. 
Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.chatCompletionsErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleCCFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.cc.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.cc.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format. +func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// handleCCFailoverExhausted writes a failover-exhausted error in CC format. 
+func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go new file mode 100644 index 00000000..d146d724 --- /dev/null +++ b/backend/internal/handler/gateway_handler_responses.go @@ -0,0 +1,295 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// Responses handles OpenAI Responses API endpoint for Anthropic platform groups. +// POST /v1/responses +// This converts Responses API requests to Anthropic format, forwards to Anthropic +// upstream, and converts responses back to Responses format. 
+func (h *GatewayHandler) Responses(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream using gjson (like OpenAI handler) + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + 
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction: + // /v1/responses is never a Claude Code endpoint. + // When claude_code_only is enabled, this endpoint is rejected. + // The existing service-layer checkClaudeCodeRestriction handles degradation + // to fallback groups when the Forward path calls SelectAccountForModelWithExclusions. + // Here we just reject at handler level since /v1/responses clients can't be Claude Code. + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.responsesErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. 
Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. 
Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.responsesErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "responses") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + // Can't failover if streaming content already sent + if c.Writer.Size() != writerSizeBeforeForward { + h.handleResponsesFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.responses.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.responses.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// responsesErrorResponse writes an error in OpenAI Responses API format. +func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": code, + "message": message, + }, + }) +} + +// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format. 
+func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return // Can't write error after stream started + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 6bcc0003..69c8d1d5 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -75,8 +75,10 @@ func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { r func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) { return nil, nil } -func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { + return 0, 0, nil +} func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil } @@ -158,6 +160,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // rpmCache nil, // digestStore nil, // settingService + nil, // tlsFPProfileService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 9e904107..4a677199 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ 
b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte { return []byte(`{ "model":"claude-3-5-sonnet-20241022", "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], - "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"} }`) } @@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing System: []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 @@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing "system": []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}, }) SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index cfe80911..524c6b6d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, 
http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } setOpsRequestContext(c, modelName, stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) @@ -593,6 +594,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE } } + // 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误 + upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) + service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") + // 使用默认的错误映射 status, message := mapGeminiUpstreamError(statusCode) googleError(c, status, message) @@ -669,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go index 82b30ee4..29d7cc41 100644 --- a/backend/internal/handler/gemini_v1beta_handler_test.go +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -3,6 +3,7 @@ package handler import ( + "net/http" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { }) } } + +func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) { + t.Parallel() + + res := 
&service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res)) +} + +func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.False(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} + +func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{ + StatusCode: http.StatusForbidden, + Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}}, + Body: []byte("insufficient authentication scopes"), + } + require.True(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 89d556cc..b2467eac 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -6,29 +6,30 @@ import ( // AdminHandlers contains all admin-related HTTP handlers type AdminHandlers struct { - Dashboard *admin.DashboardHandler - User *admin.UserHandler - Group *admin.GroupHandler - Account *admin.AccountHandler - Announcement *admin.AnnouncementHandler - DataManagement *admin.DataManagementHandler - Backup *admin.BackupHandler - OAuth *admin.OAuthHandler - OpenAIOAuth *admin.OpenAIOAuthHandler - GeminiOAuth *admin.GeminiOAuthHandler - AntigravityOAuth *admin.AntigravityOAuthHandler - Proxy *admin.ProxyHandler - Redeem *admin.RedeemHandler - Promo *admin.PromoHandler - Setting *admin.SettingHandler - Ops *admin.OpsHandler - System *admin.SystemHandler - Subscription *admin.SubscriptionHandler - Usage *admin.UsageHandler - UserAttribute *admin.UserAttributeHandler - ErrorPassthrough *admin.ErrorPassthroughHandler - APIKey *admin.AdminAPIKeyHandler - ScheduledTest *admin.ScheduledTestHandler + Dashboard *admin.DashboardHandler + User *admin.UserHandler + Group 
*admin.GroupHandler + Account *admin.AccountHandler + Announcement *admin.AnnouncementHandler + DataManagement *admin.DataManagementHandler + Backup *admin.BackupHandler + OAuth *admin.OAuthHandler + OpenAIOAuth *admin.OpenAIOAuthHandler + GeminiOAuth *admin.GeminiOAuthHandler + AntigravityOAuth *admin.AntigravityOAuthHandler + Proxy *admin.ProxyHandler + Redeem *admin.RedeemHandler + Promo *admin.PromoHandler + Setting *admin.SettingHandler + Ops *admin.OpsHandler + System *admin.SystemHandler + Subscription *admin.SubscriptionHandler + Usage *admin.UsageHandler + UserAttribute *admin.UserAttributeHandler + ErrorPassthrough *admin.ErrorPassthroughHandler + TLSFingerprintProfile *admin.TLSFingerprintProfileHandler + APIKey *admin.AdminAPIKeyHandler + ScheduledTest *admin.ScheduledTestHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 4db5cadd..0c94aa21 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -181,7 +182,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) result, err := 
h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c681e61d..ae70cee4 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct { cfg *config.Config } +func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string { + if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" { + return fallbackModel + } + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.DefaultMappedModel) +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -173,6 +183,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { @@ -530,11 +541,13 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } reqModel := modelResult.String() + routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) reqStream := gjson.GetBytes(body, "stream").Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { @@ -594,7 +607,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { apiKey.GroupID, "", // no previous_response_id sessionHash, 
- reqModel, + routingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, ) @@ -609,7 +622,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if apiKey.Group != nil { defaultModel = apiKey.Group.DefaultMappedModel } - if defaultModel != "" && defaultModel != reqModel { + if defaultModel != "" && defaultModel != routingModel { reqLog.Info("openai_messages.fallback_to_default_model", zap.String("default_mapped_model", defaultModel), ) @@ -657,9 +670,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - // 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, - // 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。 - defaultMappedModel := c.GetString("openai_messages_fallback_model") + // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 + // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() @@ -1086,6 +1099,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.String("previous_response_id_kind", previousResponseIDKind), ) setOpsRequestContext(c, reqModel, true, firstMessage) + setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) var currentUserRelease func() var currentAccountRelease func() @@ -1435,6 +1449,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE } } + // 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误 + upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) + service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") + // 使用默认的错误映射 status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) @@ -1443,6 +1461,7 @@ func (h 
*OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE // handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) + service.SetOpsUpstreamError(c, statusCode, errMsg, "") h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index a26b3a0c..7bbf94ec 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) { }) } +func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { + t.Run("prefers_explicit_fallback_model", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, + } + require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) + }) + + t.Run("uses_group_default_on_normal_path", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, + } + require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, "")) + }) + + t.Run("returns_empty_without_group_default", func(t *testing.T) { + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, "")) + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, "")) + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{ + Group: &service.Group{}, + }, "")) + }) +} + func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index ceb06f0e..90e90dd0 100644 --- a/backend/internal/handler/ops_error_logger.go +++ 
b/backend/internal/handler/ops_error_logger.go @@ -27,6 +27,9 @@ const ( opsRequestBodyKey = "ops_request_body" opsAccountIDKey = "ops_account_id" + opsUpstreamModelKey = "ops_upstream_model" + opsRequestTypeKey = "ops_request_type" + // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 opsErrContextCanceled = "context canceled" opsErrNoAvailableAccounts = "no available accounts" @@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody } } +// setOpsEndpointContext stores upstream model and request type for ops error logging. +// Called by handlers after model mapping and request type determination. +func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) { + if c == nil { + return + } + if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" { + c.Set(opsUpstreamModelKey, upstreamModel) + } + c.Set(opsRequestTypeKey, requestType) +} + func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { if c == nil || entry == nil { return @@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), + RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: "upstream", @@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), 
+ RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 679dd4ce..6ae45110 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) { }) } } + +func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream + + v, ok := c.Get(opsUpstreamModelKey) + require.True(t, ok) + vStr, ok := v.(string) + require.True(t, ok) + require.Equal(t, "claude-3-5-sonnet-20241022", vStr) + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(2), rtVal) +} + +func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "", int16(1)) + + _, ok := c.Get(opsUpstreamModelKey) + require.False(t, ok, "empty upstream model should not be stored") + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(1), rtVal) +} + +func TestSetOpsEndpointContext_NilContext(t 
*testing.T) { + require.NotPanics(t, func() { + setOpsEndpointContext(nil, "model", int16(1)) + }) +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 92061895..2c999cf1 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index dab17673..fe035b6f 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -942,6 +942,9 @@ func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, e func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } +func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } @@ -1017,6 +1020,20 @@ func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { return 0, nil } +func (r *stubAPIKeyRepoForHandler) 
UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + var updated int64 + for id, key := range r.keys { + if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID { + continue + } + clone := *key + gid := newGroupID + clone.GroupID = &gid + r.keys[id] = &clone + updated++ + } + return updated, nil +} func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { return 0, nil } @@ -2055,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { @@ -2207,7 +2224,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index dc301ce1..5e505409 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) 
ChatCompletions(c *gin.Context) { } setOpsRequestContext(c, reqModel, clientStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false))) platform := "" if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { @@ -484,6 +485,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s } func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) + service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 7170415d..c790a36c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { @@ -273,8 +273,8 @@ func (r *stubGroupRepo) 
ListActiveByPlatform(ctx context.Context, platform strin func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil @@ -345,6 +345,12 @@ func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Conte func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { return nil, nil } +func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, nil } @@ -458,6 +464,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, // rpmCache nil, // digestStore nil, // settingService + nil, // tlsFPProfileService ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index f3aadcf3..02ddd030 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -30,33 +30,35 @@ func ProvideAdminHandlers( usageHandler *admin.UsageHandler, userAttributeHandler 
*admin.UserAttributeHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler, + tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, ) *AdminHandlers { return &AdminHandlers{ - Dashboard: dashboardHandler, - User: userHandler, - Group: groupHandler, - Account: accountHandler, - Announcement: announcementHandler, - DataManagement: dataManagementHandler, - Backup: backupHandler, - OAuth: oauthHandler, - OpenAIOAuth: openaiOAuthHandler, - GeminiOAuth: geminiOAuthHandler, - AntigravityOAuth: antigravityOAuthHandler, - Proxy: proxyHandler, - Redeem: redeemHandler, - Promo: promoHandler, - Setting: settingHandler, - Ops: opsHandler, - System: systemHandler, - Subscription: subscriptionHandler, - Usage: usageHandler, - UserAttribute: userAttributeHandler, - ErrorPassthrough: errorPassthroughHandler, - APIKey: apiKeyHandler, - ScheduledTest: scheduledTestHandler, + Dashboard: dashboardHandler, + User: userHandler, + Group: groupHandler, + Account: accountHandler, + Announcement: announcementHandler, + DataManagement: dataManagementHandler, + Backup: backupHandler, + OAuth: oauthHandler, + OpenAIOAuth: openaiOAuthHandler, + GeminiOAuth: geminiOAuthHandler, + AntigravityOAuth: antigravityOAuthHandler, + Proxy: proxyHandler, + Redeem: redeemHandler, + Promo: promoHandler, + Setting: settingHandler, + Ops: opsHandler, + System: systemHandler, + Subscription: subscriptionHandler, + Usage: usageHandler, + UserAttribute: userAttributeHandler, + ErrorPassthrough: errorPassthroughHandler, + TLSFingerprintProfile: tlsFingerprintProfileHandler, + APIKey: apiKeyHandler, + ScheduledTest: scheduledTestHandler, } } @@ -145,6 +147,7 @@ var ProviderSet = wire.NewSet( admin.NewUsageHandler, admin.NewUserAttributeHandler, admin.NewErrorPassthroughHandler, + admin.NewTLSFingerprintProfileHandler, admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, diff --git 
a/backend/internal/model/tls_fingerprint_profile.go b/backend/internal/model/tls_fingerprint_profile.go new file mode 100644 index 00000000..ef57af7a --- /dev/null +++ b/backend/internal/model/tls_fingerprint_profile.go @@ -0,0 +1,54 @@ +// Package model 定义服务层使用的数据模型。 +package model + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" +) + +// TLSFingerprintProfile TLS 指纹配置模板 +// 包含完整的 ClientHello 参数,用于模拟特定客户端的 TLS 握手特征 +type TLSFingerprintProfile struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description *string `json:"description"` + EnableGREASE bool `json:"enable_grease"` + CipherSuites []uint16 `json:"cipher_suites"` + Curves []uint16 `json:"curves"` + PointFormats []uint16 `json:"point_formats"` + SignatureAlgorithms []uint16 `json:"signature_algorithms"` + ALPNProtocols []string `json:"alpn_protocols"` + SupportedVersions []uint16 `json:"supported_versions"` + KeyShareGroups []uint16 `json:"key_share_groups"` + PSKModes []uint16 `json:"psk_modes"` + Extensions []uint16 `json:"extensions"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Validate 验证模板配置的有效性 +func (p *TLSFingerprintProfile) Validate() error { + if p.Name == "" { + return &ValidationError{Field: "name", Message: "name is required"} + } + return nil +} + +// ToTLSProfile 将领域模型转换为运行时使用的 tlsfingerprint.Profile +// 空切片字段会在 dialer 中 fallback 到内置默认值 +func (p *TLSFingerprintProfile) ToTLSProfile() *tlsfingerprint.Profile { + return &tlsfingerprint.Profile{ + Name: p.Name, + EnableGREASE: p.EnableGREASE, + CipherSuites: p.CipherSuites, + Curves: p.Curves, + PointFormats: p.PointFormats, + SignatureAlgorithms: p.SignatureAlgorithms, + ALPNProtocols: p.ALPNProtocols, + SupportedVersions: p.SupportedVersions, + KeyShareGroups: p.KeyShareGroups, + PSKModes: p.PSKModes, + Extensions: p.Extensions, + } +} diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 
af3a0bfc..fdd7fea1 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -78,7 +78,9 @@ type UserInfo struct { // LoadCodeAssistRequest loadCodeAssist 请求 type LoadCodeAssistRequest struct { Metadata struct { - IDEType string `json:"ideType"` + IDEType string `json:"ideType"` + IDEVersion string `json:"ideVersion"` + IDEName string `json:"ideName"` } `json:"metadata"` } @@ -223,14 +225,40 @@ func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit { return r.PaidTier.AvailableCredits } +// TierIDToPlanType 将 tier ID 映射为用户可见的套餐名。 +func TierIDToPlanType(tierID string) string { + switch strings.ToLower(strings.TrimSpace(tierID)) { + case "free-tier": + return "Free" + case "g1-pro-tier": + return "Pro" + case "g1-ultra-tier": + return "Ultra" + default: + if tierID == "" { + return "Free" + } + return tierID + } +} + // Client Antigravity API 客户端 type Client struct { httpClient *http.Client } +const ( + // proxyDialTimeout 代理 TCP 连接超时(含代理握手),代理不通时快速失败 + proxyDialTimeout = 5 * time.Second + // proxyTLSHandshakeTimeout 代理 TLS 握手超时 + proxyTLSHandshakeTimeout = 5 * time.Second + // clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body) + clientTimeout = 10 * time.Second +) + func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ - Timeout: 30 * time.Second, + Timeout: clientTimeout, } _, parsed, err := proxyurl.Parse(proxyURL) @@ -238,7 +266,12 @@ func NewClient(proxyURL string) (*Client, error) { return nil, err } if parsed != nil { - transport := &http.Transport{} + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: proxyDialTimeout, + }).DialContext, + TLSHandshakeTimeout: proxyTLSHandshakeTimeout, + } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, fmt.Errorf("configure proxy: %w", err) } @@ -250,8 +283,8 @@ func NewClient(proxyURL string) (*Client, error) { }, nil } -// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) -func 
isConnectionError(err error) bool { +// IsConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func IsConnectionError(err error) bool { if err == nil { return false } @@ -276,7 +309,7 @@ func isConnectionError(err error) bool { // shouldFallbackToNextURL 判断是否应切换到下一个 URL // 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级 func shouldFallbackToNextURL(err error, statusCode int) bool { - if isConnectionError(err) { + if IsConnectionError(err) { return true } return statusCode == http.StatusTooManyRequests || @@ -407,6 +440,8 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" + reqBody.Metadata.IDEVersion = "1.20.6" + reqBody.Metadata.IDEName = "antigravity" bodyBytes, err := json.Marshal(reqBody) if err != nil { @@ -690,3 +725,139 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, lastErr } + +// ── Privacy API ────────────────────────────────────────────────────── + +// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致) +const privacyBaseURL = antigravityDailyBaseURL + +// SetUserSettingsRequest setUserSettings 请求体 +type SetUserSettingsRequest struct { + UserSettings map[string]any `json:"user_settings"` +} + +// FetchUserInfoRequest fetchUserInfo 请求体 +type FetchUserInfoRequest struct { + Project string `json:"project"` +} + +// FetchUserInfoResponse fetchUserInfo 响应体 +type FetchUserInfoResponse struct { + UserSettings map[string]any `json:"userSettings,omitempty"` + RegionCode string `json:"regionCode,omitempty"` +} + +// IsPrivate 判断隐私是否已设置:userSettings 为空或不含 telemetryEnabled 表示已设置 +func (r *FetchUserInfoResponse) IsPrivate() bool { + if r == nil || r.UserSettings == nil { + return true + } + _, hasTelemetry := r.UserSettings["telemetryEnabled"] + return !hasTelemetry +} + +// 
SetUserSettingsResponse setUserSettings 响应体 +type SetUserSettingsResponse struct { + UserSettings map[string]any `json:"userSettings,omitempty"` +} + +// IsSuccess 判断 setUserSettings 是否成功:返回 {"userSettings":{}} 且无 telemetryEnabled +func (r *SetUserSettingsResponse) IsSuccess() bool { + if r == nil { + return false + } + // userSettings 为 nil 或空 map 均视为成功 + if len(r.UserSettings) == 0 { + return true + } + // 如果包含 telemetryEnabled 字段,说明未成功清除 + _, hasTelemetry := r.UserSettings["telemetryEnabled"] + return !hasTelemetry +} + +// SetUserSettings 调用 setUserSettings API 设置用户隐私,返回解析后的响应 +func (c *Client) SetUserSettings(ctx context.Context, accessToken string) (*SetUserSettingsResponse, error) { + // 发送空 user_settings 以清除隐私设置 + payload := SetUserSettingsRequest{UserSettings: map[string]any{}} + bodyBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + apiURL := privacyBaseURL + "/v1internal:setUserSettings" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + req.Header.Set("User-Agent", GetUserAgent()) + req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1") + req.Host = "daily-cloudcode-pa.googleapis.com" + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("setUserSettings 请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("setUserSettings 失败 (HTTP %d): %s", resp.StatusCode, string(respBody)) + } + + var result SetUserSettingsResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("响应解析失败: %w", err) + } + + return 
&result, nil +} + +// FetchUserInfo 调用 fetchUserInfo API 获取用户隐私设置状态 +func (c *Client) FetchUserInfo(ctx context.Context, accessToken, projectID string) (*FetchUserInfoResponse, error) { + reqBody := FetchUserInfoRequest{Project: projectID} + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("序列化请求失败: %w", err) + } + + apiURL := privacyBaseURL + "/v1internal:fetchUserInfo" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + req.Header.Set("User-Agent", GetUserAgent()) + req.Header.Set("X-Goog-Api-Client", "gl-node/22.21.1") + req.Host = "daily-cloudcode-pa.googleapis.com" + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetchUserInfo 请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetchUserInfo 失败 (HTTP %d): %s", resp.StatusCode, string(respBody)) + } + + var result FetchUserInfoResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("响应解析失败: %w", err) + } + + return &result, nil +} diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index 7d5bba93..b6c2e6a5 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -250,6 +250,27 @@ func TestGetTier_两者都为nil(t *testing.T) { } } +func TestTierIDToPlanType(t *testing.T) { + tests := []struct { + tierID string + want string + }{ + {"free-tier", "Free"}, + {"g1-pro-tier", "Pro"}, + {"g1-ultra-tier", "Ultra"}, + {"FREE-TIER", "Free"}, + {"", "Free"}, + 
{"unknown-tier", "unknown-tier"}, + } + for _, tt := range tests { + t.Run(tt.tierID, func(t *testing.T) { + if got := TierIDToPlanType(tt.tierID); got != tt.want { + t.Errorf("TierIDToPlanType(%q) = %q, want %q", tt.tierID, got, tt.want) + } + }) + } +} + // --------------------------------------------------------------------------- // NewClient // --------------------------------------------------------------------------- @@ -274,8 +295,8 @@ func TestNewClient_无代理(t *testing.T) { if client.httpClient == nil { t.Fatal("httpClient 为 nil") } - if client.httpClient.Timeout != 30*time.Second { - t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout) + if client.httpClient.Timeout != clientTimeout { + t.Errorf("Timeout 不匹配: got %v, want %v", client.httpClient.Timeout, clientTimeout) } // 无代理时 Transport 应为 nil(使用默认) if client.httpClient.Transport != nil { @@ -322,11 +343,11 @@ func TestNewClient_无效代理URL(t *testing.T) { } // --------------------------------------------------------------------------- -// isConnectionError +// IsConnectionError // --------------------------------------------------------------------------- func TestIsConnectionError_nil(t *testing.T) { - if isConnectionError(nil) { + if IsConnectionError(nil) { t.Error("nil 错误不应判定为连接错误") } } @@ -338,7 +359,7 @@ func TestIsConnectionError_超时错误(t *testing.T) { Net: "tcp", Err: &timeoutError{}, } - if !isConnectionError(err) { + if !IsConnectionError(err) { t.Error("超时错误应判定为连接错误") } } @@ -356,7 +377,7 @@ func TestIsConnectionError_netOpError(t *testing.T) { Net: "tcp", Err: fmt.Errorf("connection refused"), } - if !isConnectionError(err) { + if !IsConnectionError(err) { t.Error("net.OpError 应判定为连接错误") } } @@ -367,14 +388,14 @@ func TestIsConnectionError_urlError(t *testing.T) { URL: "https://example.com", Err: fmt.Errorf("some error"), } - if !isConnectionError(err) { + if !IsConnectionError(err) { t.Error("url.Error 应判定为连接错误") } } func TestIsConnectionError_普通错误(t *testing.T) { err := 
fmt.Errorf("some random error") - if isConnectionError(err) { + if IsConnectionError(err) { t.Error("普通错误不应判定为连接错误") } } @@ -386,7 +407,7 @@ func TestIsConnectionError_包装的netOpError(t *testing.T) { Err: fmt.Errorf("connection refused"), } err := fmt.Errorf("wrapping: %w", inner) - if !isConnectionError(err) { + if !IsConnectionError(err) { t.Error("被包装的 net.OpError 应判定为连接错误") } } @@ -800,6 +821,12 @@ type redirectRoundTripper struct { transport http.RoundTripper } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { originalURL := req.URL.String() for prefix, target := range rt.redirects { @@ -1271,6 +1298,12 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { if reqBody.Metadata.IDEType != "ANTIGRAVITY" { t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) } + if strings.TrimSpace(reqBody.Metadata.IDEVersion) == "" { + t.Errorf("IDEVersion 不应为空") + } + if reqBody.Metadata.IDEName != "antigravity" { + t.Errorf("IDEName 不匹配: got %s, want antigravity", reqBody.Metadata.IDEName) + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 5bda31ac..8a8bed92 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -49,8 +49,8 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4 -var defaultUserAgentVersion = "1.20.4" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5 +var defaultUserAgentVersion = "1.20.5" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 var defaultClientSecret = 
"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" diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index f4630b09..3a093fe6 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) { if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) } - if GetUserAgent() != "antigravity/1.20.4 windows/amd64" { + if GetUserAgent() != "antigravity/1.20.5 windows/amd64" { t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) } if SessionTTL != 30*time.Minute { diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 55cdd786..1b45e507 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string { return "" } -// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 -var systemBlockFilterPrefixes = []string{ - "x-anthropic-billing-header", -} - -// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串 -func filterSystemBlockByPrefix(text string) string { - for _, prefix := range systemBlockFilterPrefixes { - if strings.HasPrefix(text, prefix) { - return "" - } - } - return text -} - // buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart @@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词和黑名单前缀 - filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr)) + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(sysStr) if filtered != "" { 
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词和黑名单前缀 - filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text)) + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(block.Text) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index f267e0e1..9e46295a 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -2,7 +2,10 @@ package antigravity import ( "encoding/json" + "strings" "testing" + + "github.com/stretchr/testify/require" ) // TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 @@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { }) } } + +func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) { + tests := []struct { + name string + system json.RawMessage + }{ + { + name: "system array", + system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`), + }, + { + name: "system string", + system: json.RawMessage(`"x-anthropic-billing-header keep"`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claudeReq := &ClaudeRequest{ + Model: "claude-3-5-sonnet-latest", + System: tt.system, + Messages: []ClaudeMessage{ + { + Role: "user", + Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), + }, + }, + } + + body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions()) + require.NoError(t, err) + + var req V1InternalRequest + 
require.NoError(t, json.Unmarshal(body, &req)) + require.NotNil(t, req.Request.SystemInstruction) + + found := false + for _, part := range req.Request.SystemInstruction.Parts { + if strings.Contains(part.Text, "x-anthropic-billing-header keep") { + found = true + break + } + } + + require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容") + }) + } +} diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 2db65572..095305c2 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -632,8 +632,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default xhigh applies. - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + // thinking.type is ignored for effort; default high applies. + assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Contains(t, resp.Include, "reasoning.encrypted_content") assert.NotContains(t, resp.Include, "reasoning.summary") @@ -650,8 +650,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default xhigh applies. - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + // thinking.type is ignored for effort; default high applies. + assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.NotContains(t, resp.Include, "reasoning.summary") } @@ -666,9 +666,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → xhigh) even when thinking is disabled. 
+ // Default effort applies (high → high) even when thinking is disabled. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } func TestAnthropicToResponses_NoThinking(t *testing.T) { @@ -680,9 +680,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → xhigh) when no thinking/output_config is set. + // Default effort applies (high → high) when no thinking/output_config is set. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- @@ -690,7 +690,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { // --------------------------------------------------------------------------- func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { - // Default is xhigh, but output_config.effort="low" overrides. low→low after mapping. + // Default is high, but output_config.effort="low" overrides. low→low after mapping. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -708,7 +708,7 @@ func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { // No thinking field, but output_config.effort="medium" → creates reasoning. - // medium→high after mapping. + // medium→medium after 1:1 mapping. 
req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -719,12 +719,12 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) } func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { - // output_config.effort="high" → mapped to "xhigh". + // output_config.effort="high" → mapped to "high" (1:1, both sides' default). req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -732,6 +732,22 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { OutputConfig: &AnthropicOutputConfig{Effort: "high"}, } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigMax(t *testing.T) { + // output_config.effort="max" → mapped to OpenAI's highest supported level "xhigh". + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "max"}, + } + resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) @@ -740,7 +756,7 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { } func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { - // No output_config → default xhigh regardless of thinking.type. + // No output_config → default high regardless of thinking.type. 
req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -751,11 +767,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { - // output_config present but effort empty (e.g. only format set) → default xhigh. + // output_config present but effort empty (e.g. only format set) → default high. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -766,7 +782,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- @@ -1008,3 +1024,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) { // Should default to image/png when media_type is empty. 
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL) } + +// --------------------------------------------------------------------------- +// normalizeToolParameters tests +// --------------------------------------------------------------------------- + +func TestNormalizeToolParameters(t *testing.T) { + tests := []struct { + name string + input json.RawMessage + expected string + }{ + { + name: "nil input", + input: nil, + expected: `{"type":"object","properties":{}}`, + }, + { + name: "empty input", + input: json.RawMessage(``), + expected: `{"type":"object","properties":{}}`, + }, + { + name: "null input", + input: json.RawMessage(`null`), + expected: `{"type":"object","properties":{}}`, + }, + { + name: "object without properties", + input: json.RawMessage(`{"type":"object"}`), + expected: `{"type":"object","properties":{}}`, + }, + { + name: "object with properties", + input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`), + expected: `{"type":"object","properties":{"city":{"type":"string"}}}`, + }, + { + name: "non-object type", + input: json.RawMessage(`{"type":"string"}`), + expected: `{"type":"string"}`, + }, + { + name: "object with additional fields preserved", + input: json.RawMessage(`{"type":"object","required":["name"]}`), + expected: `{"type":"object","required":["name"],"properties":{}}`, + }, + { + name: "invalid JSON passthrough", + input: json.RawMessage(`not json`), + expected: `not json`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeToolParameters(tt.input) + if tt.name == "invalid JSON passthrough" { + assert.Equal(t, tt.expected, string(result)) + } else { + assert.JSONEq(t, tt.expected, string(result)) + } + }) + } +} + +func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + 
Tools: []AnthropicTool{ + {Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name) + + // Parameters must have "properties" field after normalization. + var params map[string]json.RawMessage + require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms)) + assert.Contains(t, params, "properties") +} + +func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) { + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + Tools: []AnthropicTool{ + {Name: "simple_tool", Description: "A tool"}, + }, + } + + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + + require.Len(t, resp.Tools, 1) + var params map[string]json.RawMessage + require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, ¶ms)) + assert.JSONEq(t, `"object"`, string(params["type"])) + assert.JSONEq(t, `{}`, string(params["properties"])) +} diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index 0a747869..485262e8 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -46,9 +46,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { } // Determine reasoning effort: only output_config.effort controls the - // level; thinking.type is ignored. Default is xhigh when unset. - // Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh. - effort := "high" // default → maps to xhigh + // level; thinking.type is ignored. Default is high when unset (both + // Anthropic and OpenAI default to high). 
+ // Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh. + effort := "high" // default → both sides' default if req.OutputConfig != nil && req.OutputConfig.Effort != "" { effort = req.OutputConfig.Effort } @@ -380,18 +381,19 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string { // mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to // OpenAI Responses API effort levels. // +// Both APIs default to "high". The mapping is 1:1 for shared levels; +// only Anthropic's "max" (Opus 4.6 exclusive) maps to OpenAI's "xhigh" +// (GPT-5.2+ exclusive) as both represent the highest reasoning tier. +// // low → low -// medium → high -// high → xhigh +// medium → medium +// high → high +// max → xhigh func mapAnthropicEffortToResponses(effort string) string { - switch effort { - case "medium": - return "high" - case "high": + if effort == "max" { return "xhigh" - default: - return effort // "low" and any unknown values pass through unchanged } + return effort // low→low, medium→medium, high→high, unknown→passthrough } // convertAnthropicToolsToResponses maps Anthropic tool definitions to @@ -409,8 +411,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool { Type: "function", Name: t.Name, Description: t.Description, - Parameters: t.InputSchema, + Parameters: normalizeToolParameters(t.InputSchema), }) } return out } + +// normalizeToolParameters ensures the tool parameter schema is valid for +// OpenAI's Responses API, which requires "properties" on object schemas. 
+// +// - nil/empty → {"type":"object","properties":{}} +// - type=object without properties → adds "properties": {} +// - otherwise → returned unchanged +func normalizeToolParameters(schema json.RawMessage) json.RawMessage { + if len(schema) == 0 || string(schema) == "null" { + return json.RawMessage(`{"type":"object","properties":{}}`) + } + + var m map[string]json.RawMessage + if err := json.Unmarshal(schema, &m); err != nil { + return schema + } + + typ := m["type"] + if string(typ) != `"object"` { + return schema + } + + if _, ok := m["properties"]; ok { + return schema + } + + m["properties"] = json.RawMessage(`{}`) + out, err := json.Marshal(m) + if err != nil { + return schema + } + return out +} diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses_response.go b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go new file mode 100644 index 00000000..9290e399 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go @@ -0,0 +1,521 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: AnthropicResponse → ResponsesResponse +// --------------------------------------------------------------------------- + +// AnthropicToResponsesResponse converts an Anthropic Messages response into a +// Responses API response. This is the reverse of ResponsesToAnthropic and +// enables Anthropic upstream responses to be returned in OpenAI Responses format. 
+func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse { + id := resp.ID + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: resp.Model, + } + + var outputs []ResponsesOutput + var msgParts []ResponsesContentPart + + for _, block := range resp.Content { + switch block.Type { + case "thinking": + if block.Thinking != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: block.Thinking, + }}, + }) + } + case "text": + if block.Text != "" { + msgParts = append(msgParts, ResponsesContentPart{ + Type: "output_text", + Text: block.Text, + }) + } + case "tool_use": + args := "{}" + if len(block.Input) > 0 { + args = string(block.Input) + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toResponsesCallID(block.ID), + Name: block.Name, + Arguments: args, + Status: "completed", + }) + } + } + + // Assemble message output item from text parts + if len(msgParts) > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: msgParts, + Status: "completed", + }) + } + + if len(outputs) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + }) + } + out.Output = outputs + + // Map stop_reason → status + out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content) + if out.Status == "incomplete" { + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + // Usage + out.Usage = &ResponsesUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if 
resp.Usage.CacheReadInputTokens > 0 { + out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: resp.Usage.CacheReadInputTokens, + } + } + + return out +} + +// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status. +func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string { + switch stopReason { + case "max_tokens": + return "incomplete" + case "end_turn", "tool_use", "stop_sequence": + return "completed" + default: + return "completed" + } +} + +// --------------------------------------------------------------------------- +// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter) +// --------------------------------------------------------------------------- + +// AnthropicEventToResponsesState tracks state for converting a sequence of +// Anthropic SSE events into Responses SSE events. +type AnthropicEventToResponsesState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + + // CreatedSent tracks whether response.created has been emitted. + CreatedSent bool + // CompletedSent tracks whether the terminal event has been emitted. + CompletedSent bool + + // Current output tracking + OutputIndex int + CurrentItemID string + CurrentItemType string // "message" | "function_call" | "reasoning" + + // For message output: accumulate text parts + ContentIndex int + + // For function_call: track per-output info + CurrentCallID string + CurrentName string + + // Usage from message_delta + InputTokens int + OutputTokens int + CacheReadInputTokens int +} + +// NewAnthropicEventToResponsesState returns an initialised stream state. +func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState { + return &AnthropicEventToResponsesState{ + Created: time.Now().Unix(), + } +} + +// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into +// zero or more Responses SSE events, updating state as it goes. 
+func AnthropicEventToResponsesEvents( + evt *AnthropicStreamEvent, + state *AnthropicEventToResponsesState, +) []ResponsesStreamEvent { + switch evt.Type { + case "message_start": + return anthToResHandleMessageStart(evt, state) + case "content_block_start": + return anthToResHandleContentBlockStart(evt, state) + case "content_block_delta": + return anthToResHandleContentBlockDelta(evt, state) + case "content_block_stop": + return anthToResHandleContentBlockStop(evt, state) + case "message_delta": + return anthToResHandleMessageDelta(evt, state) + case "message_stop": + return anthToResHandleMessageStop(state) + default: + return nil + } +} + +// FinalizeAnthropicResponsesStream emits synthetic termination events if the +// stream ended without a proper message_stop. +func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if !state.CreatedSent || state.CompletedSent { + return nil + } + + var events []ResponsesStreamEvent + + // Close any open item + events = append(events, closeCurrentResponsesItem(state)...) + + // Emit response.completed + events = append(events, makeResponsesCompletedEvent(state, "completed", nil)) + state.CompletedSent = true + return events +} + +// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line. 
+func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) { + data, err := json.Marshal(evt) + if err != nil { + return "", err + } + return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil +} + +// --- internal handlers --- + +func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.Message != nil { + state.ResponseID = evt.Message.ID + if state.Model == "" { + state.Model = evt.Message.Model + } + if evt.Message.Usage.InputTokens > 0 { + state.InputTokens = evt.Message.Usage.InputTokens + } + } + + if state.CreatedSent { + return nil + } + state.CreatedSent = true + + // Emit response.created + return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)} +} + +func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.ContentBlock == nil { + return nil + } + + var events []ResponsesStreamEvent + + switch evt.ContentBlock.Type { + case "thinking": + state.CurrentItemID = generateItemID() + state.CurrentItemType = "reasoning" + state.ContentIndex = 0 + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "reasoning", + ID: state.CurrentItemID, + }, + })) + + case "text": + // If we don't have an open message item, open one + if state.CurrentItemType != "message" { + state.CurrentItemID = generateItemID() + state.CurrentItemType = "message" + state.ContentIndex = 0 + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "message", + ID: state.CurrentItemID, + Role: "assistant", + Status: "in_progress", + }, + })) + } + + case "tool_use": + // Close previous item if any + events = append(events, closeCurrentResponsesItem(state)...) 
+ + state.CurrentItemID = generateItemID() + state.CurrentItemType = "function_call" + state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID) + state.CurrentName = evt.ContentBlock.Name + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "function_call", + ID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + Status: "in_progress", + }, + })) + } + + return events +} + +func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.Delta == nil { + return nil + } + + switch evt.Delta.Type { + case "text_delta": + if evt.Delta.Text == "" { + return nil + } + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ContentIndex: state.ContentIndex, + Delta: evt.Delta.Text, + ItemID: state.CurrentItemID, + })} + + case "thinking_delta": + if evt.Delta.Thinking == "" { + return nil + } + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + SummaryIndex: 0, + Delta: evt.Delta.Thinking, + ItemID: state.CurrentItemID, + })} + + case "input_json_delta": + if evt.Delta.PartialJSON == "" { + return nil + } + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Delta: evt.Delta.PartialJSON, + ItemID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + })} + + case "signature_delta": + // Anthropic signature deltas have no Responses equivalent; skip + return nil + } + + return nil +} + +func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + switch 
state.CurrentItemType { + case "reasoning": + // Emit reasoning summary done + output item done + events := []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + SummaryIndex: 0, + ItemID: state.CurrentItemID, + }), + } + events = append(events, closeCurrentResponsesItem(state)...) + return events + + case "function_call": + // Emit function_call_arguments.done + output item done + events := []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ItemID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + }), + } + events = append(events, closeCurrentResponsesItem(state)...) + return events + + case "message": + // Emit output_text.done (text block is done, but message item stays open for potential more blocks) + return []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ContentIndex: state.ContentIndex, + ItemID: state.CurrentItemID, + }), + } + } + + return nil +} + +func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + // Update usage + if evt.Usage != nil { + state.OutputTokens = evt.Usage.OutputTokens + if evt.Usage.CacheReadInputTokens > 0 { + state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens + } + } + + return nil +} + +func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if state.CompletedSent { + return nil + } + + var events []ResponsesStreamEvent + + // Close any open item + events = append(events, closeCurrentResponsesItem(state)...) 
+ + // Determine status + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + + // Emit response.completed + events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails)) + state.CompletedSent = true + return events +} + +// --- helper functions --- + +func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if state.CurrentItemType == "" { + return nil + } + + itemType := state.CurrentItemType + itemID := state.CurrentItemID + + // Reset + state.CurrentItemType = "" + state.CurrentItemID = "" + state.CurrentCallID = "" + state.CurrentName = "" + state.OutputIndex++ + state.ContentIndex = 0 + + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex - 1, // Use the index before increment + Item: &ResponsesOutput{ + Type: itemType, + ID: itemID, + Status: "completed", + }, + })} +} + +func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + return ResponsesStreamEvent{ + Type: "response.created", + SequenceNumber: seq, + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + } +} + +func makeResponsesCompletedEvent( + state *AnthropicEventToResponsesState, + status string, + incompleteDetails *ResponsesIncompleteDetails, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + + usage := &ResponsesUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + TotalTokens: state.InputTokens + state.OutputTokens, + } + if state.CacheReadInputTokens > 0 { + usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: state.CacheReadInputTokens, + } + } + + return ResponsesStreamEvent{ + Type: "response.completed", + SequenceNumber: seq, + Response: 
&ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity + Usage: usage, + IncompleteDetails: incompleteDetails, + }, + } +} + +func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func generateResponsesID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "resp_" + hex.EncodeToString(b) +} + +func generateItemID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "item_" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 8b819033..f54a4a02 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -181,6 +181,35 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) { assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) } +func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`[{"type":"text","text":"You are a careful visual assistant."}]`)}, + {Role: "user", Content: json.RawMessage(`[{"type":"text","text":"Describe this image"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + + var systemParts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, 
&systemParts)) + require.Len(t, systemParts, 1) + assert.Equal(t, "input_text", systemParts[0].Type) + assert.Equal(t, "You are a careful visual assistant.", systemParts[0].Text) + + var userParts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &userParts)) + require.Len(t, userParts, 2) + assert.Equal(t, "input_image", userParts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", userParts[1].ImageURL) +} + func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { req := &ChatCompletionsRequest{ Model: "gpt-4o", @@ -398,6 +427,45 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) { assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent) } +func TestChatCompletionsToResponses_ToolArrayContent(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Use the tool"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "inspect_image", + Arguments: `{}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage( + `[{"type":"text","text":"image width: 100"},{"type":"image_url","image_url":{"url":"data:image/png;base64,ignored"}},{"type":"text","text":"; image height: 200"}]`, + ), + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 3) + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "image width: 100; image height: 200", items[2].Output) +} + func TestResponsesToChatCompletions_Incomplete(t *testing.T) { resp := &ResponsesResponse{ ID: "resp_inc", diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go 
b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index c4a9e773..6cdd012a 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -6,6 +6,11 @@ import ( "strings" ) +type chatMessageContent struct { + Text *string + Parts []ChatContentPart +} + // ChatCompletionsToResponses converts a Chat Completions request into a // Responses API request. The upstream always streams, so Stream is forced to // true. store is always false and reasoning.encrypted_content is always @@ -113,11 +118,11 @@ func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { // chatSystemToResponses converts a system message. func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { - text, err := parseChatContent(m.Content) + parsed, err := parseChatMessageContent(m.Content) if err != nil { return nil, err } - content, err := json.Marshal(text) + content, err := marshalChatInputContent(parsed) if err != nil { return nil, err } @@ -127,39 +132,11 @@ func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { // chatUserToResponses converts a user message, handling both plain strings and // multi-modal content arrays. func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { - // Try plain string first. 
- var s string - if err := json.Unmarshal(m.Content, &s); err == nil { - content, _ := json.Marshal(s) - return []ResponsesInputItem{{Role: "user", Content: content}}, nil - } - - var parts []ChatContentPart - if err := json.Unmarshal(m.Content, &parts); err != nil { + parsed, err := parseChatMessageContent(m.Content) + if err != nil { return nil, fmt.Errorf("parse user content: %w", err) } - - var responseParts []ResponsesContentPart - for _, p := range parts { - switch p.Type { - case "text": - if p.Text != "" { - responseParts = append(responseParts, ResponsesContentPart{ - Type: "input_text", - Text: p.Text, - }) - } - case "image_url": - if p.ImageURL != nil && p.ImageURL.URL != "" { - responseParts = append(responseParts, ResponsesContentPart{ - Type: "input_image", - ImageURL: p.ImageURL.URL, - }) - } - } - } - - content, err := json.Marshal(responseParts) + content, err := marshalChatInputContent(parsed) if err != nil { return nil, err } @@ -312,16 +289,79 @@ func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { } // parseChatContent returns the string value of a ChatMessage Content field. -// Content must be a JSON string. Returns "" if content is null or empty. +// Content can be a JSON string or an array of typed parts. Array content is +// flattened to text by concatenating text parts and ignoring non-text parts. 
func parseChatContent(raw json.RawMessage) (string, error) { + parsed, err := parseChatMessageContent(raw) + if err != nil { + return "", err + } + if parsed.Text != nil { + return *parsed.Text, nil + } + return flattenChatContentParts(parsed.Parts), nil +} + +func parseChatMessageContent(raw json.RawMessage) (chatMessageContent, error) { if len(raw) == 0 { - return "", nil + return chatMessageContent{Text: stringPtr("")}, nil } + var s string - if err := json.Unmarshal(raw, &s); err != nil { - return "", fmt.Errorf("parse content as string: %w", err) + if err := json.Unmarshal(raw, &s); err == nil { + return chatMessageContent{Text: &s}, nil } - return s, nil + + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + return chatMessageContent{Parts: parts}, nil + } + + return chatMessageContent{}, fmt.Errorf("parse content as string or parts array") +} + +func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error) { + if content.Text != nil { + return json.Marshal(*content.Text) + } + return json.Marshal(convertChatContentPartsToResponses(content.Parts)) +} + +func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart { + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + return responseParts +} + +func flattenChatContentParts(parts []ChatContentPart) string { + var textParts []string + for _, p := range parts { + if p.Type == "text" && p.Text != "" { + textParts = append(textParts, p.Text) + } + } + return strings.Join(textParts, "") +} + +func stringPtr(s string) *string { + return &s } // 
convertChatToolsToResponses maps Chat Completions tool definitions and legacy diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go new file mode 100644 index 00000000..f0a5b07e --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go @@ -0,0 +1,464 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ResponsesToAnthropicRequest converts a Responses API request into an +// Anthropic Messages request. This is the reverse of AnthropicToResponses and +// enables Anthropic platform groups to accept OpenAI Responses API requests +// by converting them to the native /v1/messages format before forwarding upstream. +func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) { + system, messages, err := convertResponsesInputToAnthropic(req.Input) + if err != nil { + return nil, err + } + + out := &AnthropicRequest{ + Model: req.Model, + Messages: messages, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + } + + if len(system) > 0 { + out.System = system + } + + // max_output_tokens → max_tokens + if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 { + out.MaxTokens = *req.MaxOutputTokens + } + if out.MaxTokens == 0 { + // Anthropic requires max_tokens; default to a sensible value. 
+ out.MaxTokens = 8192 + } + + // Convert tools + if len(req.Tools) > 0 { + out.Tools = convertResponsesToAnthropicTools(req.Tools) + } + + // Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses) + if len(req.ToolChoice) > 0 { + tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = tc + } + + // reasoning.effort → output_config.effort + thinking + if req.Reasoning != nil && req.Reasoning.Effort != "" { + effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort) + out.OutputConfig = &AnthropicOutputConfig{Effort: effort} + // Enable thinking for non-low efforts + if effort != "low" { + out.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: defaultThinkingBudget(effort), + } + } + } + + return out, nil +} + +// defaultThinkingBudget returns a sensible thinking budget based on effort level. +func defaultThinkingBudget(effort string) int { + switch effort { + case "low": + return 1024 + case "medium": + return 4096 + case "high": + return 10240 + case "max": + return 32768 + default: + return 10240 + } +} + +// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to +// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses. +// +// low → low +// medium → medium +// high → high +// xhigh → max +func mapResponsesEffortToAnthropic(effort string) string { + if effort == "xhigh" { + return "max" + } + return effort // low→low, medium→medium, high→high, unknown→passthrough +} + +// convertResponsesInputToAnthropic extracts system prompt and messages from +// a Responses API input array. Returns the system as raw JSON (for Anthropic's +// polymorphic system field) and a list of Anthropic messages. +func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) { + // Try as plain string input. 
+ var inputStr string + if err := json.Unmarshal(inputRaw, &inputStr); err == nil { + content, _ := json.Marshal(inputStr) + return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil + } + + var items []ResponsesInputItem + if err := json.Unmarshal(inputRaw, &items); err != nil { + return nil, nil, fmt.Errorf("parse responses input: %w", err) + } + + var system json.RawMessage + var messages []AnthropicMessage + + for _, item := range items { + switch { + case item.Role == "system": + // System prompt → Anthropic system field + text := extractTextFromContent(item.Content) + if text != "" { + system, _ = json.Marshal(text) + } + + case item.Type == "function_call": + // function_call → assistant message with tool_use block + input := json.RawMessage("{}") + if item.Arguments != "" { + input = json.RawMessage(item.Arguments) + } + block := AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallIDToAnthropic(item.CallID), + Name: item.Name, + Input: input, + } + blockJSON, _ := json.Marshal([]AnthropicContentBlock{block}) + messages = append(messages, AnthropicMessage{ + Role: "assistant", + Content: blockJSON, + }) + + case item.Type == "function_call_output": + // function_call_output → user message with tool_result block + outputContent := item.Output + if outputContent == "" { + outputContent = "(empty)" + } + contentJSON, _ := json.Marshal(outputContent) + block := AnthropicContentBlock{ + Type: "tool_result", + ToolUseID: fromResponsesCallIDToAnthropic(item.CallID), + Content: contentJSON, + } + blockJSON, _ := json.Marshal([]AnthropicContentBlock{block}) + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: blockJSON, + }) + + case item.Role == "user": + content, err := convertResponsesUserToAnthropicContent(item.Content) + if err != nil { + return nil, nil, err + } + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: content, + }) + + case item.Role == "assistant": + content, err := 
convertResponsesAssistantToAnthropicContent(item.Content) + if err != nil { + return nil, nil, err + } + messages = append(messages, AnthropicMessage{ + Role: "assistant", + Content: content, + }) + + default: + // Unknown role/type — attempt as user message + if item.Content != nil { + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: item.Content, + }) + } + } + } + + // Merge consecutive same-role messages (Anthropic requires alternating roles) + messages = mergeConsecutiveMessages(messages) + + return system, messages, nil +} + +// extractTextFromContent extracts text from a content field that may be a +// plain string or an array of content parts. +func extractTextFromContent(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, p := range parts { + if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" { + texts = append(texts, p.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// convertResponsesUserToAnthropicContent converts a Responses user message +// content field into Anthropic content blocks JSON. +func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal("") // empty string content + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Array of content parts → Anthropic content blocks. 
+ var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + // Pass through as-is if we can't parse + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "input_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + case "input_image": + src := dataURIToAnthropicImageSource(p.ImageURL) + if src != nil { + blocks = append(blocks, AnthropicContentBlock{ + Type: "image", + Source: src, + }) + } + } + } + + if len(blocks) == 0 { + return json.Marshal("") + } + return json.Marshal(blocks) +} + +// convertResponsesAssistantToAnthropicContent converts a Responses assistant +// message content field into Anthropic content blocks JSON. +func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}}) + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}}) + } + + // Array of content parts → Anthropic content blocks. + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "output_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + } + } + + if len(blocks) == 0 { + blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""}) + } + return json.Marshal(blocks) +} + +// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to +// Anthropic format. Reverses toResponsesCallID. 
+func fromResponsesCallIDToAnthropic(id string) string { + // If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it + if after, ok := strings.CutPrefix(id, "fc_"); ok { + if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") { + return after + } + } + // Generate a synthetic Anthropic tool ID + if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") { + return "toolu_" + id + } + return id +} + +// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource. +func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource { + if !strings.HasPrefix(dataURI, "data:") { + return nil + } + // Format: data:;base64, + rest := strings.TrimPrefix(dataURI, "data:") + semicolonIdx := strings.Index(rest, ";") + if semicolonIdx < 0 { + return nil + } + mediaType := rest[:semicolonIdx] + rest = rest[semicolonIdx+1:] + if !strings.HasPrefix(rest, "base64,") { + return nil + } + data := strings.TrimPrefix(rest, "base64,") + return &AnthropicImageSource{ + Type: "base64", + MediaType: mediaType, + Data: data, + } +} + +// mergeConsecutiveMessages merges consecutive messages with the same role +// because Anthropic requires alternating user/assistant turns. +func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage { + if len(messages) <= 1 { + return messages + } + + var merged []AnthropicMessage + for _, msg := range messages { + if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role { + merged = append(merged, msg) + continue + } + + // Same role — merge content arrays + last := &merged[len(merged)-1] + lastBlocks := parseContentBlocks(last.Content) + newBlocks := parseContentBlocks(msg.Content) + combined := append(lastBlocks, newBlocks...) + last.Content, _ = json.Marshal(combined) + } + return merged +} + +// parseContentBlocks attempts to parse content as []AnthropicContentBlock. +// If it's a string, wraps it in a text block. 
+func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock { + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err == nil { + return blocks + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return []AnthropicContentBlock{{Type: "text", Text: s}} + } + return nil +} + +// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format. +// Reverse of convertAnthropicToolsToResponses. +func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool { + var out []AnthropicTool + for _, t := range tools { + switch t.Type { + case "web_search": + out = append(out, AnthropicTool{ + Type: "web_search_20250305", + Name: "web_search", + }) + case "function": + out = append(out, AnthropicTool{ + Name: t.Name, + Description: t.Description, + InputSchema: normalizeAnthropicInputSchema(t.Parameters), + }) + default: + // Pass through unknown tool types + out = append(out, AnthropicTool{ + Type: t.Type, + Name: t.Name, + Description: t.Description, + InputSchema: t.Parameters, + }) + } + } + return out +} + +// normalizeAnthropicInputSchema ensures the input_schema has a "type" field. +func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { + if len(schema) == 0 || string(schema) == "null" { + return json.RawMessage(`{"type":"object","properties":{}}`) + } + return schema +} + +// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format. +// Reverse of convertAnthropicToolChoiceToResponses. 
+// +// "auto" → {"type":"auto"} +// "required" → {"type":"any"} +// "none" → {"type":"none"} +// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} +func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try as string first + var s string + if err := json.Unmarshal(raw, &s); err == nil { + switch s { + case "auto": + return json.Marshal(map[string]string{"type": "auto"}) + case "required": + return json.Marshal(map[string]string{"type": "any"}) + case "none": + return json.Marshal(map[string]string{"type": "none"}) + default: + return raw, nil + } + } + + // Try as object with type=function + var tc struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { + return json.Marshal(map[string]string{ + "type": "tool", + "name": tc.Function.Name, + }) + } + + // Pass through unknown + return raw, nil +} diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 882d2ebd..fac79d18 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -2,6 +2,8 @@ // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). 
package gemini +import "strings" + type Model struct { Name string `json:"name"` DisplayName string `json:"displayName,omitempty"` @@ -23,10 +25,27 @@ func DefaultModels() []Model { {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } +func HasFallbackModel(model string) bool { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return false + } + if !strings.HasPrefix(trimmed, "models/") { + trimmed = "models/" + trimmed + } + for _, model := range DefaultModels() { + if model.Name == trimmed { + return true + } + } + return false +} + func FallbackModelsList() ModelsListResponse { return ModelsListResponse{Models: DefaultModels()} } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go index b80047fb..1d20c0e6 100644 --- a/backend/internal/pkg/gemini/models_test.go +++ b/backend/internal/pkg/gemini/models_test.go @@ -2,7 +2,7 @@ package gemini import "testing" -func TestDefaultModels_ContainsImageModels(t *testing.T) { +func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) { t.Parallel() models := DefaultModels() @@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { required := []string{ "models/gemini-2.5-flash-image", + "models/gemini-3.1-pro-preview-customtools", "models/gemini-3.1-flash-image", } @@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { } } } + +func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) { + t.Parallel() + + if !HasFallbackModel("gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected customtools model to exist in fallback catalog") + } + if 
!HasFallbackModel("models/gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected prefixed customtools model to exist in fallback catalog") + } + if HasFallbackModel("gemini-unknown") { + t.Fatalf("did not expect unknown model to exist in fallback catalog") + } +} diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 32e4bc5b..12804cc6 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -17,6 +17,7 @@ package httpclient import ( "fmt" + "net" "net/http" "strings" "sync" @@ -32,6 +33,8 @@ const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) + defaultDialTimeout = 5 * time.Second // TCP 连接超时(含代理握手),代理不通时快速失败 + defaultTLSHandshakeTimeout = 5 * time.Second // TLS 握手超时 validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL ) @@ -107,6 +110,10 @@ func buildTransport(opts Options) (*http.Transport, error) { } transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: defaultDialTimeout, + }).DialContext, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index cfc91bee..c5ef3c6e 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -24,20 +24,18 @@ const ( RedirectURI = "https://platform.claude.com/oauth/code/callback" // Scopes - Browser URL (includes org:create_api_key for user authorization) - ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers" + ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload" // Scopes - Internal API call (org:create_api_key not supported in API) - ScopeAPI = 
"user:profile user:inference user:sessions:claude_code user:mcp_servers" + ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload" // Scopes - Setup token (inference only) ScopeInference = "user:inference" - // Code Verifier character set (RFC 7636 compliant) - codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" - // Session TTL SessionTTL = 30 * time.Minute ) // OAuthSession stores OAuth flow state + type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` @@ -147,30 +145,14 @@ func GenerateSessionID() (string, error) { return hex.EncodeToString(bytes), nil } -// GenerateCodeVerifier generates a PKCE code verifier using character set method +// GenerateCodeVerifier generates a PKCE code verifier (RFC 7636). +// Uses 32 random bytes → base64url-no-pad, producing a 43-char verifier. func GenerateCodeVerifier() (string, error) { - const targetLen = 32 - charsetLen := len(codeVerifierCharset) - limit := 256 - (256 % charsetLen) - - result := make([]byte, 0, targetLen) - randBuf := make([]byte, targetLen*2) - - for len(result) < targetLen { - if _, err := rand.Read(randBuf); err != nil { - return "", err - } - for _, b := range randBuf { - if int(b) < limit { - result = append(result, codeVerifierCharset[int(b)%charsetLen]) - if len(result) >= targetLen { - break - } - } - } + bytes, err := GenerateRandomBytes(32) + if err != nil { + return "", err } - - return base64URLEncode(result), nil + return base64URLEncode(bytes), nil } // GenerateCodeChallenge generates a PKCE code challenge using S256 method diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index b0a31a5f..49e38bf8 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -16,6 +16,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ {ID: "gpt-5.4", Object: 
"model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, + {ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"}, + {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index a35a5ea6..6b8521bd 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -270,6 +270,7 @@ type OpenAIAuthClaims struct { ChatGPTUserID string `json:"chatgpt_user_id"` ChatGPTPlanType string `json:"chatgpt_plan_type"` UserID string `json:"user_id"` + POID string `json:"poid"` // organization ID in access_token JWT Organizations []OrganizationClaim `json:"organizations"` } diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index 0519c2cc..b1d6c2d0 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) { }) } +// Accepted 返回异步接受响应 (HTTP 202) +func Accepted(c *gin.Context, data any) { + c.JSON(http.StatusAccepted, Response{ + Code: 0, + Message: "accepted", + Data: data, + }) +} + // Error 返回错误响应 func Error(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, Response{ diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 4f25a34a..c8d8369f 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ 
b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -17,12 +17,19 @@ import ( ) // Profile contains TLS fingerprint configuration. +// All slice fields use built-in defaults when empty. type Profile struct { - Name string // Profile name for identification - CipherSuites []uint16 - Curves []uint16 - PointFormats []uint8 - EnableGREASE bool + Name string // Profile name for identification + CipherSuites []uint16 + Curves []uint16 + PointFormats []uint16 + EnableGREASE bool + SignatureAlgorithms []uint16 // Empty uses defaultSignatureAlgorithms + ALPNProtocols []string // Empty uses ["http/1.1"] + SupportedVersions []uint16 // Empty uses [TLS1.3, TLS1.2] + KeyShareGroups []uint16 // Empty uses [X25519] + PSKModes []uint16 // Empty uses [psk_dhe_ke] + Extensions []uint16 // Extension type IDs in order; empty uses default Node.js 24.x order } // Dialer creates TLS connections with custom fingerprints. @@ -45,154 +52,67 @@ type SOCKS5ProxyDialer struct { proxyURL *url.URL } -// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x) -// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V -// JA3 Hash: 1a28e69016765d92e3b381168d68922c -// -// Note: JA3/JA4 may have slight variations due to: -// - Session ticket presence/absence -// - Extension negotiation state +// Default TLS fingerprint values captured from Claude Code (Node.js 24.x) +// Captured via tls-fingerprint-web capture server +// JA3 Hash: 44f88fca027f27bab4bb08d4af15f23e +// JA4: t13d1714h1_5b57614c22b0_7baf387fc6ff var ( - // defaultCipherSuites contains all 59 cipher suites from Claude CLI + // defaultCipherSuites contains the 17 cipher suites from Node.js 24.x // Order is critical for JA3 fingerprint matching defaultCipherSuites = []uint16{ - // TLS 1.3 cipher suites (MUST be first) + // TLS 1.3 cipher suites + 0x1301, // TLS_AES_128_GCM_SHA256 0x1302, // TLS_AES_256_GCM_SHA384 0x1303, // TLS_CHACHA20_POLY1305_SHA256 - 0x1301, // 
TLS_AES_128_GCM_SHA256 // ECDHE + AES-GCM - 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 - 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 - // DHE + AES-GCM - 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 - - // ECDHE/DHE + AES-CBC-SHA256/384 - 0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 - 0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 - 0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 - 0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 - - // DHE-DSS/RSA + AES-GCM - 0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 - 0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 - - // ChaCha20-Poly1305 + // ECDHE + ChaCha20-Poly1305 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - 0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - // AES-CCM (256-bit) - 0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 - 0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM - 0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8 - 0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM - - // ARIA (256-bit) - 0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 - 0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 - 0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 - 0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 - - // DHE-DSS + AES-GCM (128-bit) - 0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 - - // AES-CCM (128-bit) - 0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 - 0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM - 0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8 - 0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM - - // ARIA (128-bit) - 0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 - 0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 - 0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 - 0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 - - // ECDHE/DHE + AES-CBC-SHA384/256 (more) - 
0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 - 0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 - 0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 - 0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 - - // ECDHE/DHE + AES-CBC-SHA (legacy) - 0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA - 0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA - 0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA - 0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA + // ECDHE + AES-CBC-SHA (legacy fallback) 0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA 0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA - 0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA - 0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA + 0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + 0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA - // RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit) - 0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384 - 0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8 - 0xc09d, // TLS_RSA_WITH_AES_256_CCM - 0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384 - - // RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit) + // RSA + AES-GCM (non-PFS) 0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256 - 0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8 - 0xc09c, // TLS_RSA_WITH_AES_128_CCM - 0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256 + 0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384 - // RSA + AES-CBC (non-PFS, legacy) - 0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256 - 0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256 - 0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA + // RSA + AES-CBC-SHA (non-PFS, legacy) 0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA - - // Renegotiation indication - 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV + 0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA } - // defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE) + // defaultCurves contains the 3 supported groups from Node.js 24.x defaultCurves = []utls.CurveID{ - utls.X25519, // 0x001d - utls.CurveP256, // 0x0017 (secp256r1) - utls.CurveID(0x001e), // x448 - utls.CurveP521, // 0x0019 (secp521r1) - utls.CurveP384, // 0x0018 
(secp384r1) - utls.CurveID(0x0100), // ffdhe2048 - utls.CurveID(0x0101), // ffdhe3072 - utls.CurveID(0x0102), // ffdhe4096 - utls.CurveID(0x0103), // ffdhe6144 - utls.CurveID(0x0104), // ffdhe8192 + utls.X25519, // 0x001d + utls.CurveP256, // 0x0017 (secp256r1) + utls.CurveP384, // 0x0018 (secp384r1) } - // defaultPointFormats contains all 3 point formats from Claude CLI - defaultPointFormats = []uint8{ + // defaultPointFormats contains point formats from Node.js 24.x + defaultPointFormats = []uint16{ 0, // uncompressed - 1, // ansiX962_compressed_prime - 2, // ansiX962_compressed_char2 } - // defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI + // defaultSignatureAlgorithms contains the 9 signature algorithms from Node.js 24.x defaultSignatureAlgorithms = []utls.SignatureScheme{ 0x0403, // ecdsa_secp256r1_sha256 - 0x0503, // ecdsa_secp384r1_sha384 - 0x0603, // ecdsa_secp521r1_sha512 - 0x0807, // ed25519 - 0x0808, // ed448 - 0x0809, // rsa_pss_pss_sha256 - 0x080a, // rsa_pss_pss_sha384 - 0x080b, // rsa_pss_pss_sha512 0x0804, // rsa_pss_rsae_sha256 - 0x0805, // rsa_pss_rsae_sha384 - 0x0806, // rsa_pss_rsae_sha512 0x0401, // rsa_pkcs1_sha256 + 0x0503, // ecdsa_secp384r1_sha384 + 0x0805, // rsa_pss_rsae_sha384 0x0501, // rsa_pkcs1_sha384 + 0x0806, // rsa_pss_rsae_sha512 0x0601, // rsa_pkcs1_sha512 - 0x0303, // ecdsa_sha224 - 0x0301, // rsa_pkcs1_sha224 - 0x0302, // dsa_sha224 - 0x0402, // dsa_sha256 - 0x0502, // dsa_sha384 - 0x0602, // dsa_sha512 + 0x0201, // rsa_pkcs1_sha1 } ) @@ -256,49 +176,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st slog.Debug("tls_fingerprint_socks5_tunnel_established") // Step 3: Perform TLS handshake on the tunnel with utls fingerprint - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host) - - // Build ClientHello specification from profile (Node.js/Claude CLI fingerprint) - spec 
:= buildClientHelloSpecFromProfile(d.profile) - slog.Debug("tls_fingerprint_socks5_clienthello_spec", - "cipher_suites", len(spec.CipherSuites), - "extensions", len(spec.Extensions), - "compression_methods", spec.CompressionMethods, - "tls_vers_max", spec.TLSVersMax, - "tls_vers_min", spec.TLSVersMin) - - if d.profile != nil { - slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) - } - - // Create uTLS connection on the tunnel - tlsConn := utls.UClient(conn, &utls.Config{ - ServerName: host, - }, utls.HelloCustom) - - if err := tlsConn.ApplyPreset(spec); err != nil { - slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err) - _ = conn.Close() - return nil, fmt.Errorf("apply TLS preset: %w", err) - } - - if err := tlsConn.HandshakeContext(ctx); err != nil { - slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) - _ = conn.Close() - return nil, fmt.Errorf("TLS handshake failed: %w", err) - } - - state := tlsConn.ConnectionState() - slog.Debug("tls_fingerprint_socks5_handshake_success", - "version", state.Version, - "cipher_suite", state.CipherSuite, - "alpn", state.NegotiatedProtocol) - - return tlsConn, nil + return performTLSHandshake(ctx, conn, d.profile, addr) } // DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint. @@ -358,7 +236,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err) return nil, fmt.Errorf("read CONNECT response: %w", err) } - defer func() { _ = resp.Body.Close() }() + // CONNECT response has no body; do not defer resp.Body.Close() as it wraps the + // same conn that will be used for the TLS handshake. 
if resp.StatusCode != http.StatusOK { _ = conn.Close() @@ -368,47 +247,7 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri slog.Debug("tls_fingerprint_http_proxy_tunnel_established") // Step 4: Perform TLS handshake on the tunnel with utls fingerprint - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host) - - // Build ClientHello specification (reuse the shared method) - spec := buildClientHelloSpecFromProfile(d.profile) - slog.Debug("tls_fingerprint_http_proxy_clienthello_spec", - "cipher_suites", len(spec.CipherSuites), - "extensions", len(spec.Extensions)) - - if d.profile != nil { - slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) - } - - // Create uTLS connection on the tunnel - // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions - tlsConn := utls.UClient(conn, &utls.Config{ - ServerName: host, - }, utls.HelloCustom) - - if err := tlsConn.ApplyPreset(spec); err != nil { - slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err) - _ = conn.Close() - return nil, fmt.Errorf("apply TLS preset: %w", err) - } - - if err := tlsConn.HandshakeContext(ctx); err != nil { - slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err) - _ = conn.Close() - return nil, fmt.Errorf("TLS handshake failed: %w", err) - } - - state := tlsConn.ConnectionState() - slog.Debug("tls_fingerprint_http_proxy_handshake_success", - "version", state.Version, - "cipher_suite", state.CipherSuite, - "alpn", state.NegotiatedProtocol) - - return tlsConn, nil + return performTLSHandshake(ctx, conn, d.profile, addr) } // DialTLSContext establishes a TLS connection with the configured fingerprint. @@ -423,53 +262,35 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. 
} slog.Debug("tls_fingerprint_tcp_connected", "addr", addr) - // Extract hostname for SNI + // Perform TLS handshake with utls fingerprint + return performTLSHandshake(ctx, conn, d.profile, addr) +} + +// performTLSHandshake performs the uTLS handshake on an established connection. +// It builds a ClientHello spec from the profile, applies it, and completes the handshake. +// On failure, conn is closed and an error is returned. +func performTLSHandshake(ctx context.Context, conn net.Conn, profile *Profile, addr string) (net.Conn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { host = addr } - slog.Debug("tls_fingerprint_sni_hostname", "host", host) - // Build ClientHello specification - spec := d.buildClientHelloSpec() - slog.Debug("tls_fingerprint_clienthello_spec", - "cipher_suites", len(spec.CipherSuites), - "extensions", len(spec.Extensions)) + spec := buildClientHelloSpecFromProfile(profile) + tlsConn := utls.UClient(conn, &utls.Config{ServerName: host}, utls.HelloCustom) - // Log profile info - if d.profile != nil { - slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) - } else { - slog.Debug("tls_fingerprint_using_default_profile") - } - - // Create uTLS connection - // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions - tlsConn := utls.UClient(conn, &utls.Config{ - ServerName: host, - }, utls.HelloCustom) - - // Apply fingerprint if err := tlsConn.ApplyPreset(spec); err != nil { - slog.Debug("tls_fingerprint_apply_preset_failed", "error", err) _ = conn.Close() - return nil, err + return nil, fmt.Errorf("apply TLS preset: %w", err) } - slog.Debug("tls_fingerprint_preset_applied") - // Perform TLS handshake if err := tlsConn.HandshakeContext(ctx); err != nil { - slog.Debug("tls_fingerprint_handshake_failed", - "error", err, - "local_addr", conn.LocalAddr(), - "remote_addr", conn.RemoteAddr()) _ = conn.Close() return nil, fmt.Errorf("TLS 
handshake failed: %w", err) } - // Log successful handshake details state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_handshake_success", + "host", host, "version", state.Version, "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) @@ -477,11 +298,6 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. return tlsConn, nil } -// buildClientHelloSpec constructs the ClientHello specification based on the profile. -func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec { - return buildClientHelloSpecFromProfile(d.profile) -} - // toUTLSCurves converts uint16 slice to utls.CurveID slice. func toUTLSCurves(curves []uint16) []utls.CurveID { result := make([]utls.CurveID, len(curves)) @@ -491,70 +307,143 @@ func toUTLSCurves(curves []uint16) []utls.CurveID { return result } +// defaultExtensionOrder is the Node.js 24.x extension order. +// Used when Profile.Extensions is empty. +var defaultExtensionOrder = []uint16{ + 0, // server_name + 65037, // encrypted_client_hello + 23, // extended_master_secret + 65281, // renegotiation_info + 10, // supported_groups + 11, // ec_point_formats + 35, // session_ticket + 16, // alpn + 5, // status_request + 13, // signature_algorithms + 18, // signed_certificate_timestamp + 51, // key_share + 45, // psk_key_exchange_modes + 43, // supported_versions +} + +// isGREASEValue checks if a uint16 value matches the TLS GREASE pattern (0x?a?a). +func isGREASEValue(v uint16) bool { + return v&0x0f0f == 0x0a0a && v>>8 == v&0xff +} + // buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile. // This is a standalone function that can be used by both Dialer and HTTPProxyDialer. 
func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec { - // Get cipher suites - var cipherSuites []uint16 + // Resolve effective values (profile overrides or built-in defaults) + cipherSuites := defaultCipherSuites if profile != nil && len(profile.CipherSuites) > 0 { cipherSuites = profile.CipherSuites - } else { - cipherSuites = defaultCipherSuites } - // Get curves - var curves []utls.CurveID + curves := defaultCurves if profile != nil && len(profile.Curves) > 0 { curves = toUTLSCurves(profile.Curves) - } else { - curves = defaultCurves } - // Get point formats - var pointFormats []uint8 + pointFormats := defaultPointFormats if profile != nil && len(profile.PointFormats) > 0 { pointFormats = profile.PointFormats - } else { - pointFormats = defaultPointFormats } - // Check if GREASE is enabled + signatureAlgorithms := defaultSignatureAlgorithms + if profile != nil && len(profile.SignatureAlgorithms) > 0 { + signatureAlgorithms = make([]utls.SignatureScheme, len(profile.SignatureAlgorithms)) + for i, s := range profile.SignatureAlgorithms { + signatureAlgorithms[i] = utls.SignatureScheme(s) + } + } + + alpnProtocols := []string{"http/1.1"} + if profile != nil && len(profile.ALPNProtocols) > 0 { + alpnProtocols = profile.ALPNProtocols + } + + supportedVersions := []uint16{utls.VersionTLS13, utls.VersionTLS12} + if profile != nil && len(profile.SupportedVersions) > 0 { + supportedVersions = profile.SupportedVersions + } + + keyShareGroups := []utls.CurveID{utls.X25519} + if profile != nil && len(profile.KeyShareGroups) > 0 { + keyShareGroups = toUTLSCurves(profile.KeyShareGroups) + } + + pskModes := []uint16{uint16(utls.PskModeDHE)} + if profile != nil && len(profile.PSKModes) > 0 { + pskModes = profile.PSKModes + } + enableGREASE := profile != nil && profile.EnableGREASE - extensions := make([]utls.TLSExtension, 0, 16) - - if enableGREASE { - extensions = append(extensions, &utls.UtlsGREASEExtension{}) + // Build key shares + keyShares := 
make([]utls.KeyShare, len(keyShareGroups)) + for i, g := range keyShareGroups { + keyShares[i] = utls.KeyShare{Group: g} } - // SNI extension - MUST be explicitly added for HelloCustom mode - // utls will populate the server name from Config.ServerName - extensions = append(extensions, &utls.SNIExtension{}) + // Determine extension order + extOrder := defaultExtensionOrder + if profile != nil && len(profile.Extensions) > 0 { + extOrder = profile.Extensions + } - // Claude CLI extension order (captured from tshark): - // server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35), - // alpn(16), encrypt_then_mac(22), extended_master_secret(23), - // signature_algorithms(13), supported_versions(43), - // psk_key_exchange_modes(45), key_share(51) - extensions = append(extensions, - &utls.SupportedPointsExtension{SupportedPoints: pointFormats}, - &utls.SupportedCurvesExtension{Curves: curves}, - &utls.SessionTicketExtension{}, - &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}, - &utls.GenericExtension{Id: 22}, - &utls.ExtendedMasterSecretExtension{}, - &utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms}, - &utls.SupportedVersionsExtension{Versions: []uint16{ - utls.VersionTLS13, - utls.VersionTLS12, - }}, - &utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}}, - &utls.KeyShareExtension{KeyShares: []utls.KeyShare{ - {Group: utls.X25519}, - }}, - ) + // Build extensions list from the ordered IDs. + // Parametric extensions (curves, sigalgs, etc.) are populated with resolved profile values. + // Unknown IDs use GenericExtension (sends type ID with empty data). 
+ extensions := make([]utls.TLSExtension, 0, len(extOrder)+2) + for _, id := range extOrder { + if isGREASEValue(id) { + extensions = append(extensions, &utls.UtlsGREASEExtension{}) + continue + } + switch id { + case 0: // server_name + extensions = append(extensions, &utls.SNIExtension{}) + case 5: // status_request (OCSP) + extensions = append(extensions, &utls.StatusRequestExtension{}) + case 10: // supported_groups + extensions = append(extensions, &utls.SupportedCurvesExtension{Curves: curves}) + case 11: // ec_point_formats + extensions = append(extensions, &utls.SupportedPointsExtension{SupportedPoints: toUint8s(pointFormats)}) + case 13: // signature_algorithms + extensions = append(extensions, &utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: signatureAlgorithms}) + case 16: // alpn + extensions = append(extensions, &utls.ALPNExtension{AlpnProtocols: alpnProtocols}) + case 18: // signed_certificate_timestamp + extensions = append(extensions, &utls.SCTExtension{}) + case 23: // extended_master_secret + extensions = append(extensions, &utls.ExtendedMasterSecretExtension{}) + case 35: // session_ticket + extensions = append(extensions, &utls.SessionTicketExtension{}) + case 43: // supported_versions + extensions = append(extensions, &utls.SupportedVersionsExtension{Versions: supportedVersions}) + case 45: // psk_key_exchange_modes + extensions = append(extensions, &utls.PSKKeyExchangeModesExtension{Modes: toUint8s(pskModes)}) + case 50: // signature_algorithms_cert + extensions = append(extensions, &utls.SignatureAlgorithmsCertExtension{SupportedSignatureAlgorithms: signatureAlgorithms}) + case 51: // key_share + extensions = append(extensions, &utls.KeyShareExtension{KeyShares: keyShares}) + case 0xfe0d: // encrypted_client_hello (ECH, 65037) + // Send GREASE ECH with random payload — mimics Node.js behavior when no real ECHConfig is available. 
+ // An empty GenericExtension causes "error decoding message" from servers that validate ECH format. + extensions = append(extensions, &utls.GREASEEncryptedClientHelloExtension{}) + case 0xff01: // renegotiation_info + extensions = append(extensions, &utls.RenegotiationInfoExtension{}) + default: + // Unknown extension — send as GenericExtension (type ID + empty data). + // This covers encrypt_then_mac(22) and any future extensions. + extensions = append(extensions, &utls.GenericExtension{Id: id}) + } + } - if enableGREASE { + // For default extension order with EnableGREASE, wrap with GREASE bookends + if enableGREASE && (profile == nil || len(profile.Extensions) == 0) { + extensions = append([]utls.TLSExtension{&utls.UtlsGREASEExtension{}}, extensions...) extensions = append(extensions, &utls.UtlsGREASEExtension{}) } @@ -566,3 +455,12 @@ func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec { TLSVersMin: utls.VersionTLS10, } } + +// toUint8s converts []uint16 to []uint8 (for utls fields that require []uint8). +func toUint8s(vals []uint16) []uint8 { + out := make([]uint8, len(vals)) + for i, v := range vals { + out[i] = uint8(v) + } + return out +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer_capture_test.go b/backend/internal/pkg/tlsfingerprint/dialer_capture_test.go new file mode 100644 index 00000000..de9d79a0 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer_capture_test.go @@ -0,0 +1,368 @@ +//go:build integration + +package tlsfingerprint + +import ( + "context" + "encoding/json" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + utls "github.com/refraction-networking/utls" +) + +// CapturedFingerprint mirrors the Fingerprint struct from tls-fingerprint-web. +// Used to deserialize the JSON response from the capture server. 
+type CapturedFingerprint struct { + JA3Raw string `json:"ja3_raw"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + HTTP2 string `json:"http2"` + CipherSuites []int `json:"cipher_suites"` + Curves []int `json:"curves"` + PointFormats []int `json:"point_formats"` + Extensions []int `json:"extensions"` + SignatureAlgorithms []int `json:"signature_algorithms"` + ALPNProtocols []string `json:"alpn_protocols"` + SupportedVersions []int `json:"supported_versions"` + KeyShareGroups []int `json:"key_share_groups"` + PSKModes []int `json:"psk_modes"` + CompressCertAlgos []int `json:"compress_cert_algos"` + EnableGREASE bool `json:"enable_grease"` +} + +// TestDialerAgainstCaptureServer connects to the tls-fingerprint-web capture server +// and verifies that the dialer's TLS fingerprint matches the configured Profile. +// +// Default capture server: https://tls.sub2api.org:8090 +// Override with env: TLSFINGERPRINT_CAPTURE_URL=https://localhost:8443 +// +// Run: go test -v -run TestDialerAgainstCaptureServer ./internal/pkg/tlsfingerprint/... 
+func TestDialerAgainstCaptureServer(t *testing.T) { + captureURL := os.Getenv("TLSFINGERPRINT_CAPTURE_URL") + if captureURL == "" { + captureURL = "https://tls.sub2api.org:8090" + } + + tests := []struct { + name string + profile *Profile + }{ + { + name: "default_profile", + profile: &Profile{ + Name: "default", + EnableGREASE: false, + // All empty → uses built-in defaults + }, + }, + { + name: "linux_x64_node_v22171", + profile: &Profile{ + Name: "linux_x64_node_v22171", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint16{0, 1, 2}, + SignatureAlgorithms: []uint16{0x0403, 0x0503, 0x0603, 0x0807, 0x0808, 0x0809, 0x080a, 0x080b, 0x0804, 0x0805, 0x0806, 0x0401, 0x0501, 0x0601, 0x0303, 0x0301, 0x0302, 0x0402, 0x0502, 0x0602}, + ALPNProtocols: []string{"http/1.1"}, + SupportedVersions: []uint16{0x0304, 0x0303}, + KeyShareGroups: []uint16{29}, + PSKModes: []uint16{1}, + Extensions: []uint16{0, 11, 10, 35, 16, 22, 23, 13, 43, 45, 51}, + }, + }, + { + name: "macos_arm64_node_v2430", + profile: &Profile{ + Name: "MacOS_arm64_node_v2430", + EnableGREASE: false, + CipherSuites: []uint16{4865, 4866, 4867, 49195, 49199, 49196, 49200, 52393, 52392, 49161, 49171, 49162, 49172, 156, 157, 47, 53}, + Curves: []uint16{29, 23, 24}, + PointFormats: []uint16{0}, + SignatureAlgorithms: []uint16{0x0403, 0x0804, 0x0401, 0x0503, 0x0805, 0x0501, 0x0806, 0x0601, 0x0201}, + ALPNProtocols: []string{"http/1.1"}, + SupportedVersions: []uint16{0x0304, 0x0303}, + KeyShareGroups: []uint16{29}, + PSKModes: []uint16{1}, + Extensions: []uint16{0, 65037, 23, 
65281, 10, 11, 35, 16, 5, 13, 18, 51, 45, 43}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + captured := fetchCapturedFingerprint(t, captureURL, tc.profile) + if captured == nil { + return + } + + t.Logf("JA3 Hash: %s", captured.JA3Hash) + t.Logf("JA4: %s", captured.JA4) + + // Resolve effective profile values (what the dialer actually uses) + effectiveCipherSuites := tc.profile.CipherSuites + if len(effectiveCipherSuites) == 0 { + effectiveCipherSuites = defaultCipherSuites + } + effectiveCurves := tc.profile.Curves + if len(effectiveCurves) == 0 { + effectiveCurves = make([]uint16, len(defaultCurves)) + for i, c := range defaultCurves { + effectiveCurves[i] = uint16(c) + } + } + effectivePointFormats := tc.profile.PointFormats + if len(effectivePointFormats) == 0 { + effectivePointFormats = defaultPointFormats + } + effectiveSigAlgs := tc.profile.SignatureAlgorithms + if len(effectiveSigAlgs) == 0 { + effectiveSigAlgs = make([]uint16, len(defaultSignatureAlgorithms)) + for i, s := range defaultSignatureAlgorithms { + effectiveSigAlgs[i] = uint16(s) + } + } + effectiveALPN := tc.profile.ALPNProtocols + if len(effectiveALPN) == 0 { + effectiveALPN = []string{"http/1.1"} + } + effectiveVersions := tc.profile.SupportedVersions + if len(effectiveVersions) == 0 { + effectiveVersions = []uint16{0x0304, 0x0303} + } + effectiveKeyShare := tc.profile.KeyShareGroups + if len(effectiveKeyShare) == 0 { + effectiveKeyShare = []uint16{29} // X25519 + } + effectivePSKModes := tc.profile.PSKModes + if len(effectivePSKModes) == 0 { + effectivePSKModes = []uint16{1} // psk_dhe_ke + } + + // Verify each field + assertIntSliceEqual(t, "cipher_suites", uint16sToInts(effectiveCipherSuites), captured.CipherSuites) + assertIntSliceEqual(t, "curves", uint16sToInts(effectiveCurves), captured.Curves) + assertIntSliceEqual(t, "point_formats", uint16sToInts(effectivePointFormats), captured.PointFormats) + assertIntSliceEqual(t, "signature_algorithms", 
uint16sToInts(effectiveSigAlgs), captured.SignatureAlgorithms) + assertStringSliceEqual(t, "alpn_protocols", effectiveALPN, captured.ALPNProtocols) + assertIntSliceEqual(t, "supported_versions", uint16sToInts(effectiveVersions), captured.SupportedVersions) + assertIntSliceEqual(t, "key_share_groups", uint16sToInts(effectiveKeyShare), captured.KeyShareGroups) + assertIntSliceEqual(t, "psk_modes", uint16sToInts(effectivePSKModes), captured.PSKModes) + + if captured.EnableGREASE != tc.profile.EnableGREASE { + t.Errorf("enable_grease: got %v, want %v", captured.EnableGREASE, tc.profile.EnableGREASE) + } else { + t.Logf(" enable_grease: %v OK", captured.EnableGREASE) + } + + // Verify extension order + // Use profile.Extensions if set, otherwise the default order (Node.js 24.x) + expectedExtOrder := uint16sToInts(defaultExtensionOrder) + if len(tc.profile.Extensions) > 0 { + expectedExtOrder = uint16sToInts(tc.profile.Extensions) + } + // Strip GREASE values from both expected and captured for comparison + var filteredExpected, filteredActual []int + for _, e := range expectedExtOrder { + if !isGREASEValue(uint16(e)) { + filteredExpected = append(filteredExpected, e) + } + } + for _, e := range captured.Extensions { + if !isGREASEValue(uint16(e)) { + filteredActual = append(filteredActual, e) + } + } + assertIntSliceEqual(t, "extensions (order, non-GREASE)", filteredExpected, filteredActual) + + // Print full captured data as JSON for debugging + capturedJSON, _ := json.MarshalIndent(captured, " ", " ") + t.Logf("Full captured fingerprint:\n %s", string(capturedJSON)) + }) + } +} + +func fetchCapturedFingerprint(t *testing.T, captureURL string, profile *Profile) *CapturedFingerprint { + t.Helper() + + dialer := NewDialer(profile, nil) + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 10 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, 
err := http.NewRequestWithContext(ctx, "POST", captureURL, strings.NewReader(`{"model":"test"}`)) + if err != nil { + t.Fatalf("create request: %v", err) + return nil + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-token") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + return nil + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + return nil + } + + var fp CapturedFingerprint + if err := json.Unmarshal(body, &fp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("parse response: %v", err) + return nil + } + + return &fp +} + +func uint16sToInts(vals []uint16) []int { + result := make([]int, len(vals)) + for i, v := range vals { + result[i] = int(v) + } + return result +} + +func assertIntSliceEqual(t *testing.T, name string, expected, actual []int) { + t.Helper() + if len(expected) != len(actual) { + t.Errorf("%s: length mismatch: got %d, want %d", name, len(actual), len(expected)) + if len(actual) < 20 && len(expected) < 20 { + t.Errorf(" got: %v", actual) + t.Errorf(" want: %v", expected) + } + return + } + mismatches := 0 + for i := range expected { + if expected[i] != actual[i] { + if mismatches < 5 { + t.Errorf("%s[%d]: got %d (0x%04x), want %d (0x%04x)", name, i, actual[i], actual[i], expected[i], expected[i]) + } + mismatches++ + } + } + if mismatches == 0 { + t.Logf(" %s: %d items OK", name, len(expected)) + } else if mismatches > 5 { + t.Errorf(" %s: %d/%d mismatches (showing first 5)", name, mismatches, len(expected)) + } +} + +func assertStringSliceEqual(t *testing.T, name string, expected, actual []string) { + t.Helper() + if len(expected) != len(actual) { + t.Errorf("%s: length mismatch: got %d (%v), want %d (%v)", name, len(actual), actual, len(expected), expected) + return + } + for i := range expected { + if expected[i] != actual[i] { + t.Errorf("%s[%d]: got %q, 
want %q", name, i, actual[i], expected[i]) + return + } + } + t.Logf(" %s: %v OK", name, expected) +} + +// TestBuildClientHelloSpecNewFields tests that new Profile fields are correctly applied. +func TestBuildClientHelloSpecNewFields(t *testing.T) { + // Test custom ALPN, versions, key shares, PSK modes + profile := &Profile{ + Name: "custom_full", + EnableGREASE: false, + CipherSuites: []uint16{0x1301, 0x1302}, + Curves: []uint16{29, 23}, + PointFormats: []uint16{0}, + SignatureAlgorithms: []uint16{0x0403, 0x0804}, + ALPNProtocols: []string{"h2", "http/1.1"}, + SupportedVersions: []uint16{0x0304}, + KeyShareGroups: []uint16{29, 23}, + PSKModes: []uint16{1}, + } + + spec := buildClientHelloSpecFromProfile(profile) + + // Verify cipher suites + if len(spec.CipherSuites) != 2 || spec.CipherSuites[0] != 0x1301 { + t.Errorf("cipher suites: got %v", spec.CipherSuites) + } + + // Check extensions for expected values + var foundALPN, foundVersions, foundKeyShare, foundPSK, foundSigAlgs bool + for _, ext := range spec.Extensions { + switch e := ext.(type) { + case *utls.ALPNExtension: + foundALPN = true + if len(e.AlpnProtocols) != 2 || e.AlpnProtocols[0] != "h2" { + t.Errorf("ALPN: got %v, want [h2, http/1.1]", e.AlpnProtocols) + } + case *utls.SupportedVersionsExtension: + foundVersions = true + if len(e.Versions) != 1 || e.Versions[0] != 0x0304 { + t.Errorf("versions: got %v, want [0x0304]", e.Versions) + } + case *utls.KeyShareExtension: + foundKeyShare = true + if len(e.KeyShares) != 2 { + t.Errorf("key shares: got %d entries, want 2", len(e.KeyShares)) + } + case *utls.PSKKeyExchangeModesExtension: + foundPSK = true + if len(e.Modes) != 1 || e.Modes[0] != 1 { + t.Errorf("PSK modes: got %v, want [1]", e.Modes) + } + case *utls.SignatureAlgorithmsExtension: + foundSigAlgs = true + if len(e.SupportedSignatureAlgorithms) != 2 { + t.Errorf("sig algs: got %d, want 2", len(e.SupportedSignatureAlgorithms)) + } + } + } + + for name, found := range map[string]bool{ + "ALPN": 
foundALPN, "Versions": foundVersions, "KeyShare": foundKeyShare, + "PSK": foundPSK, "SigAlgs": foundSigAlgs, + } { + if !found { + t.Errorf("extension %s not found in spec", name) + } + } + + // Test nil profile uses all defaults + specDefault := buildClientHelloSpecFromProfile(nil) + for _, ext := range specDefault.Extensions { + switch e := ext.(type) { + case *utls.ALPNExtension: + if len(e.AlpnProtocols) != 1 || e.AlpnProtocols[0] != "http/1.1" { + t.Errorf("default ALPN: got %v, want [http/1.1]", e.AlpnProtocols) + } + case *utls.SupportedVersionsExtension: + if len(e.Versions) != 2 { + t.Errorf("default versions: got %v, want 2 entries", e.Versions) + } + case *utls.KeyShareExtension: + if len(e.KeyShares) != 1 { + t.Errorf("default key shares: got %d, want 1", len(e.KeyShares)) + } + } + } + + t.Log("TestBuildClientHelloSpecNewFields passed") +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go index 3f668fbe..38cddd0d 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go @@ -40,16 +40,15 @@ func skipIfExternalServiceUnavailable(t *testing.T, err error) { // TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. // This test uses tls.peet.ws to verify the fingerprint. -// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) -// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... 
(i=IP) +// Expected JA3 hash: 44f88fca027f27bab4bb08d4af15f23e (Node.js 24.x) +// Expected JA4: t13d1714h1_5b57614c22b0_7baf387fc6ff func TestJA3Fingerprint(t *testing.T) { - // Skip if network is unavailable or if running in short mode if testing.Short() { t.Skip("skipping integration test in short mode") } profile := &Profile{ - Name: "Claude CLI Test", + Name: "Default Profile Test", EnableGREASE: false, } dialer := NewDialer(profile, nil) @@ -61,7 +60,6 @@ func TestJA3Fingerprint(t *testing.T) { Timeout: 30 * time.Second, } - // Use tls.peet.ws fingerprint detection API ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -69,7 +67,7 @@ func TestJA3Fingerprint(t *testing.T) { if err != nil { t.Fatalf("failed to create request: %v", err) } - req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/24.3.0") resp, err := client.Do(req) skipIfExternalServiceUnavailable(t, err) @@ -86,71 +84,23 @@ func TestJA3Fingerprint(t *testing.T) { t.Fatalf("failed to parse fingerprint response: %v", err) } - // Log all fingerprint information t.Logf("JA3: %s", fpResp.TLS.JA3) t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) t.Logf("JA4: %s", fpResp.TLS.JA4) - t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) - t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) - // Verify JA3 hash matches expected value - expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + expectedJA3Hash := "44f88fca027f27bab4bb08d4af15f23e" if fpResp.TLS.JA3Hash == expectedJA3Hash { - t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + t.Logf("✓ JA3 hash matches: %s", expectedJA3Hash) } else { t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) } - // Verify JA4 fingerprint - // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] - // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) - // The suffix 
_a33745022dd6_1f22a2ca17c4 should match - expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" - if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { - t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + expectedJA4CipherHash := "_5b57614c22b0_" + if strings.Contains(fpResp.TLS.JA4, expectedJA4CipherHash) { + t.Logf("✓ JA4 cipher hash matches: %s", expectedJA4CipherHash) } else { - t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + t.Errorf("✗ JA4 cipher hash mismatch: got %s, expected containing %s", fpResp.TLS.JA4, expectedJA4CipherHash) } - - // Verify JA4 prefix (t13d5911h1 or t13i5911h1) - // d = domain (SNI present), i = IP (no SNI) - // Since we connect to tls.peet.ws (domain), we expect 'd' - expectedJA4Prefix := "t13d5911h1" - if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { - t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) - } else { - // Also accept 'i' variant for IP connections - altPrefix := "t13i5911h1" - if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { - t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) - } else { - t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) - } - } - - // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) - if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { - t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") - } else { - t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") - } - - // Verify extension list (should be 11 extensions including SNI) - // Expected: 0-11-10-35-16-22-23-13-43-45-51 - expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" - if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { - t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) - } else { - t.Logf("Warning: JA3 extension list may differ") - } -} - -// 
TestProfileExpectation defines expected fingerprint values for a profile. -type TestProfileExpectation struct { - Profile *Profile - ExpectedJA3 string // Expected JA3 hash (empty = don't check) - ExpectedJA4 string // Expected full JA4 (empty = don't check) - JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) } // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. @@ -164,30 +114,24 @@ func TestAllProfiles(t *testing.T) { // These profiles are from config.yaml gateway.tls_fingerprint.profiles profiles := []TestProfileExpectation{ { - // Linux x64 Node.js v22.17.1 - // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c - // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + // Default profile (Node.js 24.x) + Profile: &Profile{ + Name: "default_node_v24", + EnableGREASE: false, + }, + JA4CipherHash: "5b57614c22b0", + }, + { + // Linux x64 Node.js v22.17.1 (explicit profile with v22 extensions) Profile: &Profile{ Name: "linux_x64_node_v22171", EnableGREASE: false, CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, + PointFormats: []uint16{0, 1, 2}, + Extensions: []uint16{0, 11, 10, 35, 16, 22, 23, 13, 43, 45, 51}, }, - JA4CipherHash: "a33745022dd6", // stable part - }, - { - // MacOS arm64 Node.js v22.18.0 - // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea - // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 - Profile: &Profile{ - Name: "macos_arm64_node_v22180", - EnableGREASE: false, - CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 
49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, - Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, - }, - JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + JA4CipherHash: "a33745022dd6", }, } diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 6d3db174..048418c9 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -55,13 +55,13 @@ func TestDialerBasicConnection(t *testing.T) { // TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. // This test uses tls.peet.ws to verify the fingerprint. -// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) -// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... 
(i=IP) +// Expected JA3 hash: 44f88fca027f27bab4bb08d4af15f23e (Node.js 24.x) +// Expected JA4: t13d1714h1_5b57614c22b0_7baf387fc6ff func TestJA3Fingerprint(t *testing.T) { skipNetworkTest(t) profile := &Profile{ - Name: "Claude CLI Test", + Name: "Default Profile Test", EnableGREASE: false, } dialer := NewDialer(profile, nil) @@ -81,7 +81,7 @@ func TestJA3Fingerprint(t *testing.T) { if err != nil { t.Fatalf("failed to create request: %v", err) } - req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/24.3.0") resp, err := client.Do(req) if err != nil { @@ -107,34 +107,28 @@ func TestJA3Fingerprint(t *testing.T) { t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) - // Verify JA3 hash matches expected value - expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + // Verify JA3 hash matches expected value (Node.js 24.x default) + expectedJA3Hash := "44f88fca027f27bab4bb08d4af15f23e" if fpResp.TLS.JA3Hash == expectedJA3Hash { t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) } else { t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) } - // Verify JA4 fingerprint - // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] - // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) - // The suffix _a33745022dd6_1f22a2ca17c4 should match - expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" - if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { - t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + // Verify JA4 cipher hash (stable middle part) + expectedJA4CipherHash := "_5b57614c22b0_" + if strings.Contains(fpResp.TLS.JA4, expectedJA4CipherHash) { + t.Logf("✓ JA4 cipher hash matches: %s", expectedJA4CipherHash) } else { - t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + t.Errorf("✗ JA4 cipher hash 
mismatch: got %s, expected containing %s", fpResp.TLS.JA4, expectedJA4CipherHash) } - // Verify JA4 prefix (t13d5911h1 or t13i5911h1) - // d = domain (SNI present), i = IP (no SNI) - // Since we connect to tls.peet.ws (domain), we expect 'd' - expectedJA4Prefix := "t13d5911h1" + // Verify JA4 prefix (t13d1714h1 or t13i1714h1) + expectedJA4Prefix := "t13d1714h1" if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { - t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 17=ciphers, 14=extensions, h1=HTTP/1.1)", expectedJA4Prefix) } else { - // Also accept 'i' variant for IP connections - altPrefix := "t13i5911h1" + altPrefix := "t13i1714h1" if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) } else { @@ -142,16 +136,15 @@ func TestJA3Fingerprint(t *testing.T) { } } - // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) - if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + // Verify JA3 contains expected TLS 1.3 cipher suites + if strings.Contains(fpResp.TLS.JA3, "4865-4866-4867") { t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") } else { t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") } - // Verify extension list (should be 11 extensions including SNI) - // Expected: 0-11-10-35-16-22-23-13-43-45-51 - expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + // Verify extension list (14 extensions, Node.js 24.x order) + expectedExtensions := "0-65037-23-65281-10-11-35-16-5-13-18-51-45-43" if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) } else { @@ -186,8 +179,8 @@ func TestDialerWithProfile(t *testing.T) { // Build specs and compare // Note: We can't directly compare JA3 without making network requests // but we can verify the specs are 
different - spec1 := dialer1.buildClientHelloSpec() - spec2 := dialer2.buildClientHelloSpec() + spec1 := buildClientHelloSpecFromProfile(dialer1.profile) + spec2 := buildClientHelloSpecFromProfile(dialer2.profile) // Profile with GREASE should have more extensions if len(spec2.Extensions) <= len(spec1.Extensions) { @@ -296,47 +289,33 @@ func mustParseURL(rawURL string) *url.URL { return u } -// TestProfileExpectation defines expected fingerprint values for a profile. -type TestProfileExpectation struct { - Profile *Profile - ExpectedJA3 string // Expected JA3 hash (empty = don't check) - ExpectedJA4 string // Expected full JA4 (empty = don't check) - JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) -} - // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. // Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... func TestAllProfiles(t *testing.T) { skipNetworkTest(t) - // Define all profiles to test with their expected fingerprints - // These profiles are from config.yaml gateway.tls_fingerprint.profiles profiles := []TestProfileExpectation{ { - // Linux x64 Node.js v22.17.1 - // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c - // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + // Default profile (Node.js 24.x) + // JA3 Hash: 44f88fca027f27bab4bb08d4af15f23e + // JA4: t13d1714h1_5b57614c22b0_7baf387fc6ff + Profile: &Profile{ + Name: "default_node_v24", + EnableGREASE: false, + }, + JA4CipherHash: "5b57614c22b0", + }, + { + // Linux x64 Node.js v22.17.1 (explicit profile) Profile: &Profile{ Name: "linux_x64_node_v22171", EnableGREASE: false, CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 
49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, + PointFormats: []uint16{0, 1, 2}, + Extensions: []uint16{0, 11, 10, 35, 16, 22, 23, 13, 43, 45, 51}, }, - JA4CipherHash: "a33745022dd6", // stable part - }, - { - // MacOS arm64 Node.js v22.18.0 - // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea - // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 - Profile: &Profile{ - Name: "macos_arm64_node_v22180", - EnableGREASE: false, - CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, - Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, - }, - JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + JA4CipherHash: "a33745022dd6", }, } diff --git a/backend/internal/pkg/tlsfingerprint/registry.go b/backend/internal/pkg/tlsfingerprint/registry.go deleted file mode 100644 index 6e9dc539..00000000 --- a/backend/internal/pkg/tlsfingerprint/registry.go +++ /dev/null @@ -1,171 +0,0 @@ -// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. -package tlsfingerprint - -import ( - "log/slog" - "sort" - "sync" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -// DefaultProfileName is the name of the built-in Claude CLI profile. -const DefaultProfileName = "claude_cli_v2" - -// Registry manages TLS fingerprint profiles. -// It holds a collection of profiles that can be used for TLS fingerprint simulation. -// Profiles are selected based on account ID using modulo operation. 
-type Registry struct { - mu sync.RWMutex - profiles map[string]*Profile - profileNames []string // Sorted list of profile names for deterministic selection -} - -// NewRegistry creates a new TLS fingerprint profile registry. -// It initializes with the built-in default profile. -func NewRegistry() *Registry { - r := &Registry{ - profiles: make(map[string]*Profile), - profileNames: make([]string, 0), - } - - // Register the built-in default profile - r.registerBuiltinProfile() - - return r -} - -// NewRegistryFromConfig creates a new registry and loads profiles from config. -// If the config has custom profiles defined, they will be merged with the built-in default. -func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry { - r := NewRegistry() - - if cfg == nil || !cfg.Enabled { - slog.Debug("tls_registry_disabled", "reason", "disabled or no config") - return r - } - - // Load custom profiles from config - for name, profileCfg := range cfg.Profiles { - profile := &Profile{ - Name: profileCfg.Name, - EnableGREASE: profileCfg.EnableGREASE, - CipherSuites: profileCfg.CipherSuites, - Curves: profileCfg.Curves, - PointFormats: profileCfg.PointFormats, - } - - // If the profile has empty values, they will use defaults in dialer - r.RegisterProfile(name, profile) - slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name) - } - - slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames) - return r -} - -// registerBuiltinProfile adds the default Claude CLI profile to the registry. 
-func (r *Registry) registerBuiltinProfile() { - defaultProfile := &Profile{ - Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)", - EnableGREASE: false, // Node.js does not use GREASE - // Empty slices will cause dialer to use built-in defaults - CipherSuites: nil, - Curves: nil, - PointFormats: nil, - } - r.RegisterProfile(DefaultProfileName, defaultProfile) -} - -// RegisterProfile adds or updates a profile in the registry. -func (r *Registry) RegisterProfile(name string, profile *Profile) { - r.mu.Lock() - defer r.mu.Unlock() - - // Check if this is a new profile - _, exists := r.profiles[name] - r.profiles[name] = profile - - if !exists { - r.profileNames = append(r.profileNames, name) - // Keep names sorted for deterministic selection - sort.Strings(r.profileNames) - } -} - -// GetProfile returns a profile by name. -// Returns nil if the profile does not exist. -func (r *Registry) GetProfile(name string) *Profile { - r.mu.RLock() - defer r.mu.RUnlock() - return r.profiles[name] -} - -// GetDefaultProfile returns the built-in default profile. -func (r *Registry) GetDefaultProfile() *Profile { - return r.GetProfile(DefaultProfileName) -} - -// GetProfileByAccountID returns a profile for the given account ID. -// The profile is selected using: profileNames[accountID % len(profiles)] -// This ensures deterministic profile assignment for each account. -func (r *Registry) GetProfileByAccountID(accountID int64) *Profile { - r.mu.RLock() - defer r.mu.RUnlock() - - if len(r.profileNames) == 0 { - return nil - } - - // Use modulo to select profile index - // Use absolute value to handle negative IDs (though unlikely) - idx := accountID - if idx < 0 { - idx = -idx - } - selectedIndex := int(idx % int64(len(r.profileNames))) - selectedName := r.profileNames[selectedIndex] - - return r.profiles[selectedName] -} - -// ProfileCount returns the number of registered profiles. 
-func (r *Registry) ProfileCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - return len(r.profiles) -} - -// ProfileNames returns a sorted list of all registered profile names. -func (r *Registry) ProfileNames() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - // Return a copy to prevent modification - names := make([]string, len(r.profileNames)) - copy(names, r.profileNames) - return names -} - -// Global registry instance for convenience -var globalRegistry *Registry -var globalRegistryOnce sync.Once - -// GlobalRegistry returns the global TLS fingerprint registry. -// The registry is lazily initialized with the default profile. -func GlobalRegistry() *Registry { - globalRegistryOnce.Do(func() { - globalRegistry = NewRegistry() - }) - return globalRegistry -} - -// InitGlobalRegistry initializes the global registry with configuration. -// This should be called during application startup. -// It is safe to call multiple times; subsequent calls will update the registry. -func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry { - globalRegistryOnce.Do(func() { - globalRegistry = NewRegistryFromConfig(cfg) - }) - return globalRegistry -} diff --git a/backend/internal/pkg/tlsfingerprint/registry_test.go b/backend/internal/pkg/tlsfingerprint/registry_test.go deleted file mode 100644 index 752ba0cc..00000000 --- a/backend/internal/pkg/tlsfingerprint/registry_test.go +++ /dev/null @@ -1,243 +0,0 @@ -package tlsfingerprint - -import ( - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -func TestNewRegistry(t *testing.T) { - r := NewRegistry() - - // Should have exactly one profile (the default) - if r.ProfileCount() != 1 { - t.Errorf("expected 1 profile, got %d", r.ProfileCount()) - } - - // Should have the default profile - profile := r.GetDefaultProfile() - if profile == nil { - t.Error("expected default profile to exist") - } - - // Default profile name should be in the list - names := r.ProfileNames() - if len(names) != 1 || 
names[0] != DefaultProfileName { - t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names) - } -} - -func TestRegisterProfile(t *testing.T) { - r := NewRegistry() - - // Register a new profile - customProfile := &Profile{ - Name: "Custom Profile", - EnableGREASE: true, - } - r.RegisterProfile("custom", customProfile) - - // Should now have 2 profiles - if r.ProfileCount() != 2 { - t.Errorf("expected 2 profiles, got %d", r.ProfileCount()) - } - - // Should be able to retrieve the custom profile - retrieved := r.GetProfile("custom") - if retrieved == nil { - t.Fatal("expected custom profile to exist") - } - if retrieved.Name != "Custom Profile" { - t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name) - } - if !retrieved.EnableGREASE { - t.Error("expected EnableGREASE to be true") - } -} - -func TestGetProfile(t *testing.T) { - r := NewRegistry() - - // Get existing profile - profile := r.GetProfile(DefaultProfileName) - if profile == nil { - t.Error("expected default profile to exist") - } - - // Get non-existing profile - nonExistent := r.GetProfile("nonexistent") - if nonExistent != nil { - t.Error("expected nil for non-existent profile") - } -} - -func TestGetProfileByAccountID(t *testing.T) { - r := NewRegistry() - - // With only default profile, all account IDs should return the same profile - for i := int64(0); i < 10; i++ { - profile := r.GetProfileByAccountID(i) - if profile == nil { - t.Errorf("expected profile for account %d, got nil", i) - } - } - - // Add more profiles - r.RegisterProfile("profile_a", &Profile{Name: "Profile A"}) - r.RegisterProfile("profile_b", &Profile{Name: "Profile B"}) - - // Now we have 3 profiles: claude_cli_v2, profile_a, profile_b - // Names are sorted, so order is: claude_cli_v2, profile_a, profile_b - expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"} - names := r.ProfileNames() - for i, name := range expectedOrder { - if names[i] != name { - 
t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) - } - } - - // Test modulo selection - // Account ID 0 % 3 = 0 -> claude_cli_v2 - // Account ID 1 % 3 = 1 -> profile_a - // Account ID 2 % 3 = 2 -> profile_b - // Account ID 3 % 3 = 0 -> claude_cli_v2 - testCases := []struct { - accountID int64 - expectedName string - }{ - {0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, - {1, "Profile A"}, - {2, "Profile B"}, - {3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, - {4, "Profile A"}, - {5, "Profile B"}, - {100, "Profile A"}, // 100 % 3 = 1 - {-1, "Profile A"}, // |-1| % 3 = 1 - {-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0 - } - - for _, tc := range testCases { - profile := r.GetProfileByAccountID(tc.accountID) - if profile == nil { - t.Errorf("expected profile for account %d, got nil", tc.accountID) - continue - } - if profile.Name != tc.expectedName { - t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name) - } - } -} - -func TestNewRegistryFromConfig(t *testing.T) { - // Test with nil config - r := NewRegistryFromConfig(nil) - if r.ProfileCount() != 1 { - t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount()) - } - - // Test with disabled config - disabledCfg := &config.TLSFingerprintConfig{ - Enabled: false, - } - r = NewRegistryFromConfig(disabledCfg) - if r.ProfileCount() != 1 { - t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount()) - } - - // Test with enabled config and custom profiles - enabledCfg := &config.TLSFingerprintConfig{ - Enabled: true, - Profiles: map[string]config.TLSProfileConfig{ - "custom1": { - Name: "Custom Profile 1", - EnableGREASE: true, - }, - "custom2": { - Name: "Custom Profile 2", - EnableGREASE: false, - }, - }, - } - r = NewRegistryFromConfig(enabledCfg) - - // Should have 3 profiles: default + 2 custom - if r.ProfileCount() != 3 { - t.Errorf("expected 3 profiles, got %d", r.ProfileCount()) 
- } - - // Check custom profiles exist - custom1 := r.GetProfile("custom1") - if custom1 == nil || custom1.Name != "Custom Profile 1" { - t.Error("expected custom1 profile to exist with correct name") - } - custom2 := r.GetProfile("custom2") - if custom2 == nil || custom2.Name != "Custom Profile 2" { - t.Error("expected custom2 profile to exist with correct name") - } -} - -func TestProfileNames(t *testing.T) { - r := NewRegistry() - - // Add profiles in non-alphabetical order - r.RegisterProfile("zebra", &Profile{Name: "Zebra"}) - r.RegisterProfile("alpha", &Profile{Name: "Alpha"}) - r.RegisterProfile("beta", &Profile{Name: "Beta"}) - - names := r.ProfileNames() - - // Should be sorted alphabetically - expected := []string{"alpha", "beta", DefaultProfileName, "zebra"} - if len(names) != len(expected) { - t.Errorf("expected %d names, got %d", len(expected), len(names)) - } - for i, name := range expected { - if names[i] != name { - t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) - } - } - - // Test that returned slice is a copy (modifying it shouldn't affect registry) - names[0] = "modified" - originalNames := r.ProfileNames() - if originalNames[0] == "modified" { - t.Error("modifying returned slice should not affect registry") - } -} - -func TestConcurrentAccess(t *testing.T) { - r := NewRegistry() - - // Run concurrent reads and writes - done := make(chan bool) - - // Writers - for i := 0; i < 10; i++ { - go func(id int) { - for j := 0; j < 100; j++ { - r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"}) - } - done <- true - }(i) - } - - // Readers - for i := 0; i < 10; i++ { - go func(id int) { - for j := 0; j < 100; j++ { - _ = r.ProfileCount() - _ = r.ProfileNames() - _ = r.GetProfileByAccountID(int64(id * j)) - _ = r.GetProfile(DefaultProfileName) - } - done <- true - }(i) - } - - // Wait for all goroutines - for i := 0; i < 20; i++ { - <-done - } - - // Test should pass without data races (run with 
-race flag) -} diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go index 2bbf2d22..1711100d 100644 --- a/backend/internal/pkg/tlsfingerprint/test_types_test.go +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -8,6 +8,14 @@ type FingerprintResponse struct { HTTP2 any `json:"http2"` } +// TestProfileExpectation defines expected fingerprint values for a profile. +type TestProfileExpectation struct { + Profile *Profile + ExpectedJA3 string // Expected JA3 hash (empty = don't check) + ExpectedJA4 string // Expected full JA4 (empty = don't check) + JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) +} + // TLSInfo contains TLS fingerprint details. type TLSInfo struct { JA3 string `json:"ja3"` diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 99c9cda7..44cddb6a 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -3,6 +3,28 @@ package usagestats import "time" +const ( + ModelSourceRequested = "requested" + ModelSourceUpstream = "upstream" + ModelSourceMapping = "mapping" +) + +func IsValidModelSource(source string) bool { + switch source { + case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping: + return true + default: + return false + } +} + +func NormalizeModelSource(source string) string { + if IsValidModelSource(source) { + return source + } + return ModelSourceRequested +} + // DashboardStats 仪表盘统计 type DashboardStats struct { // 用户统计 @@ -90,6 +112,13 @@ type EndpointStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// GroupUsageSummary represents today's and cumulative cost for a single group. 
+type GroupUsageSummary struct { + GroupID int64 `json:"group_id"` + TodayCost float64 `json:"today_cost"` + TotalCost float64 `json:"total_cost"` +} + // GroupStat represents usage statistics for a single group type GroupStat struct { GroupID int64 `json:"group_id"` @@ -129,6 +158,25 @@ type UserSpendingRankingResponse struct { TotalTokens int64 `json:"total_tokens"` } +// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint). +type UserBreakdownItem struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Requests int64 `json:"requests"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + +// UserBreakdownDimension specifies the dimension to filter for user breakdown. +type UserBreakdownDimension struct { + GroupID int64 // filter by group_id (>0 to enable) + Model string // filter by model name (non-empty to enable) + ModelType string // "requested", "upstream", or "mapping" + Endpoint string // filter by endpoint value (non-empty to enable) + EndpointType string // "inbound", "upstream", or "path" +} + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint struct { Date string `json:"date"` diff --git a/backend/internal/pkg/usagestats/usage_log_types_test.go b/backend/internal/pkg/usagestats/usage_log_types_test.go new file mode 100644 index 00000000..95cf6069 --- /dev/null +++ b/backend/internal/pkg/usagestats/usage_log_types_test.go @@ -0,0 +1,47 @@ +package usagestats + +import "testing" + +func TestIsValidModelSource(t *testing.T) { + tests := []struct { + name string + source string + want bool + }{ + {name: "requested", source: ModelSourceRequested, want: true}, + {name: "upstream", source: ModelSourceUpstream, want: true}, + {name: "mapping", source: ModelSourceMapping, want: true}, + {name: "invalid", source: "foobar", want: false}, + {name: "empty", source: "", want: false}, 
+ } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := IsValidModelSource(tc.source); got != tc.want { + t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want) + } + }) + } +} + +func TestNormalizeModelSource(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + {name: "requested", source: ModelSourceRequested, want: ModelSourceRequested}, + {name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream}, + {name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping}, + {name: "invalid falls back", source: "foobar", want: ModelSourceRequested}, + {name: "empty falls back", source: "", want: ModelSourceRequested}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeModelSource(tc.source); got != tc.want { + t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want) + } + }) + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 20ff7373..d45e8a12 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -56,6 +56,7 @@ var schedulerNeutralExtraKeyPrefixes = []string{ "codex_secondary_", "codex_5h_", "codex_7d_", + "passive_usage_", } var schedulerNeutralExtraKeys = map[string]struct{}{ @@ -403,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account return nil } +func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + _, err := r.client.Account.UpdateOneID(id). + SetCredentials(normalizeJSONMap(credentials)). 
+ Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + r.syncSchedulerAccountSnapshot(ctx, id) + return nil +} + func (r *accountRepository) Delete(ctx context.Context, id int64) error { groupIDs, err := r.loadAccountGroupIDs(ctx, id) if err != nil { @@ -442,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "", 0) + return r.ListWithFilters(ctx, params, "", "", "", "", 0, "") } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -473,9 +485,25 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } - if groupID > 0 { + if groupID == service.AccountListGroupUngrouped { + q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups())) + } else if groupID > 0 { q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) } + if privacyMode != "" { + q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { + path := sqljson.Path("privacy_mode") + switch privacyMode { + case service.AccountPrivacyModeUnsetFilter: + s.Where(entsql.Or( + entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)), + sqljson.ValueEQ(dbaccount.FieldExtra, "", path), + )) + default: + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, 
path)) + } + })) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index e697802e..8da30c92 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -208,14 +208,16 @@ func (s *AccountRepoSuite) TestList() { func (s *AccountRepoSuite) TestListWithFilters() { tests := []struct { - name string - setup func(client *dbent.Client) - platform string - accType string - status string - search string - wantCount int - validate func(accounts []service.Account) + name string + setup func(client *dbent.Client) + platform string + accType string + status string + search string + groupID int64 + privacyMode string + wantCount int + validate func(accounts []service.Account) }{ { name: "filter_by_platform", @@ -265,6 +267,47 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Contains(accounts[0].Name, "alpha") }, }, + { + name: "filter_by_ungrouped", + setup: func(client *dbent.Client) { + group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"}) + grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"}) + mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"}) + mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1) + }, + groupID: service.AccountListGroupUngrouped, + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("ungrouped-account", accounts[0].Name) + s.Require().Empty(accounts[0].GroupIDs) + }, + }, + { + name: "filter_by_privacy_mode", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": 
service.PrivacyModeFailed}}) + }, + privacyMode: service.PrivacyModeTrainingOff, + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("privacy-ok", accounts[0].Name) + }, + }, + { + name: "filter_by_privacy_mode_unset", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}}) + }, + privacyMode: service.AccountPrivacyModeUnsetFilter, + wantCount: 2, + validate: func(accounts []service.Account) { + names := []string{accounts[0].Name, accounts[1].Name} + s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names) + }, + }, } for _, tt := range tests { @@ -277,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -344,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) 
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 4c7f38a8..ade0d464 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "fmt" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -257,9 +258,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro } func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { + // 存在唯一键约束 生成tombstone key 用来释放原key,长度远小于 128,满足 schema 限制 + tombstoneKey := fmt.Sprintf("__deleted__%d__%d", id, time.Now().UnixNano()) // 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。 affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetKey(tombstoneKey). SetDeletedAt(time.Now()). Save(ctx) if err != nil { @@ -409,6 +413,16 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in return int64(n), err } +// UpdateGroupIDByUserAndGroup 将用户下绑定 oldGroupID 的所有 Key 迁移到 newGroupID +func (r *apiKeyRepository) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + client := clientFromContext(ctx, r.client) + n, err := client.APIKey.Update(). + Where(apikey.UserIDEQ(userID), apikey.GroupIDEQ(oldGroupID), apikey.DeletedAtIsNil()). + SetGroupID(newGroupID). 
+ Save(ctx) + return int64(n), err +} + // CountByGroupID 获取分组的 API Key 数量 func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx) @@ -648,6 +662,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { SupportedModelScopes: g.SupportedModelScopes, SortOrder: g.SortOrder, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOauthOnly, + RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index a8989ff2..7d5c1826 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -151,6 +151,31 @@ func (s *APIKeyRepoSuite) TestDelete() { s.Require().Error(err, "expected error after delete") } +func (s *APIKeyRepoSuite) TestCreate_AfterSoftDelete_AllowsSameKey() { + user := s.mustCreateUser("recreate-after-soft-delete@test.com") + const reusedKey = "sk-reuse-after-soft-delete" + + first := &service.APIKey{ + UserID: user.ID, + Key: reusedKey, + Name: "First Key", + Status: service.StatusActive, + } + s.Require().NoError(s.repo.Create(s.ctx, first), "create first key") + + s.Require().NoError(s.repo.Delete(s.ctx, first.ID), "soft delete first key") + + second := &service.APIKey{ + UserID: user.ID, + Key: reusedKey, + Name: "Second Key", + Status: service.StatusActive, + } + s.Require().NoError(s.repo.Create(s.ctx, second), "create second key with same key") + s.Require().NotZero(second.ID) + s.Require().NotEqual(first.ID, second.ID, "recreated key should be a new row") +} + // --- ListByUserID / CountByUserID --- func (s *APIKeyRepoSuite) TestListByUserID() { diff --git a/backend/internal/repository/backup_s3_store.go 
b/backend/internal/repository/backup_s3_store.go index ba5434f5..5d419f57 100644 --- a/backend/internal/repository/backup_s3_store.go +++ b/backend/internal/repository/backup_s3_store.go @@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory { func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) { // 读取全部内容以获取大小(S3 PutObject 需要知道内容长度) + // 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject data, err := io.ReadAll(body) if err != nil { return 0, fmt.Errorf("read body: %w", err) diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 4fbdae14..6922b4c8 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -20,6 +20,11 @@ const ( billingCacheTTL = 5 * time.Minute billingCacheJitter = 30 * time.Second rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window + + // Rate limit window durations — must match service.RateLimitWindow* constants. + rateLimitWindow5h = 5 * time.Hour + rateLimitWindow1d = 24 * time.Hour + rateLimitWindow7d = 7 * 24 * time.Hour ) // jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 @@ -90,17 +95,40 @@ var ( return 1 `) - // updateRateLimitUsageScript atomically increments all three rate limit usage counters. - // Returns 0 if the key doesn't exist (cache miss), 1 on success. + // updateRateLimitUsageScript atomically increments all three rate limit usage counters + // with window expiration checking. If a window has expired, its usage is reset to cost + // (instead of accumulated) and the window timestamp is updated, matching the DB-side + // IncrementRateLimitUsage semantics. 
+ // + // ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds updateRateLimitUsageScript = redis.NewScript(` local exists = redis.call('EXISTS', KEYS[1]) if exists == 0 then return 0 end local cost = tonumber(ARGV[1]) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost) + local now = tonumber(ARGV[3]) + local win5h = tonumber(ARGV[4]) + local win1d = tonumber(ARGV[5]) + local win7d = tonumber(ARGV[6]) + + -- Helper: check if window is expired and update usage + window accordingly + -- Returns nothing, modifies the hash in-place. + local function update_window(usage_field, window_field, window_duration) + local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0) + if w == 0 or (now - w) >= window_duration then + -- Window expired or never started: reset usage to cost, start new window + redis.call('HSET', KEYS[1], usage_field, tostring(cost)) + redis.call('HSET', KEYS[1], window_field, tostring(now)) + else + -- Window still valid: accumulate + redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost) + end + end + + update_window('usage_5h', 'window_5h', win5h) + update_window('usage_1d', 'window_1d', win1d) + update_window('usage_7d', 'window_7d', win7d) redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 `) @@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { key := billingRateLimitKey(keyID) - _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result() + now := time.Now().Unix() + _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, + cost, + int(rateLimitCacheTTL.Seconds()), + now, + int(rateLimitWindow5h.Seconds()), + int(rateLimitWindow1d.Seconds()), + 
int(rateLimitWindow7d.Seconds()), + ).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) return err diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index b754bd55..fee5c645 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -212,7 +212,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod SetContext(ctx). SetHeader("Accept", "application/json, text/plain, */*"). SetHeader("Content-Type", "application/json"). - SetHeader("User-Agent", "axios/1.8.4"). + SetHeader("User-Agent", "axios/1.13.6"). SetBody(reqBody). SetSuccessResult(&tokenResp). Post(s.tokenURL) @@ -250,7 +250,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro SetContext(ctx). SetHeader("Accept", "application/json, text/plain, */*"). SetHeader("Content-Type", "application/json"). - SetHeader("User-Agent", "axios/1.8.4"). + SetHeader("User-Agent", "axios/1.13.6"). SetBody(reqBody). SetSuccessResult(&tokenResp). 
Post(s.tokenURL) diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1264f6bb..b44adde2 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -68,10 +68,9 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se var resp *http.Response - // 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS - if opts.EnableTLSFingerprint && s.httpUpstream != nil { - // accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置 - resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true) + // 如果有 TLS Profile 且有 HTTPUpstream,使用 DoWithTLS + if opts.TLSProfile != nil && s.httpUpstream != nil { + resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, opts.TLSProfile) if err != nil { return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err) } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c195f1f1..3cfd649b 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -61,6 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetMcpXMLInject(groupIn.MCPXMLInject). SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). 
SetDefaultMappedModel(groupIn.DefaultMappedModel) // 设置模型路由配置 @@ -88,8 +90,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - count, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = count + total, active, _ := r.GetAccountCount(ctx, out.ID) + out.AccountCount = total + out.ActiveAccountCount = active return out, nil } @@ -129,6 +132,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetMcpXMLInject(groupIn.MCPXMLInject). SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 @@ -256,7 +261,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -283,7 +291,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -310,7 +321,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = 
c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -369,12 +383,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int return result, nil } -func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - var count int64 - if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { - return 0, err - } - return count, nil +func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { + var rateLimited int64 + err = scanSingleRow(ctx, r.sql, + `SELECT COUNT(*), + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) + FROM account_groups ag JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = $1`, + []any{groupID}, &total, &active, &rateLimited) + return } func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -500,15 +522,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return affectedUserIDs, nil } -func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { - counts = make(map[int64]int64, len(groupIDs)) +type groupAccountCounts struct { + Total int64 + Active int64 + RateLimited int64 +} + +func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { + counts = make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { return counts, nil } rows, err := r.sql.QueryContext( ctx, - "SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id", + `SELECT ag.group_id, 
+ COUNT(*) AS total, + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) AS rate_limited + FROM account_groups ag + JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = ANY($1) + GROUP BY ag.group_id`, pq.Array(groupIDs), ) if err != nil { @@ -523,11 +562,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 for rows.Next() { var groupID int64 - var count int64 - if err = rows.Scan(&groupID, &count); err != nil { + var c groupAccountCounts + if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil { return nil, err } - counts[groupID] = count + counts[groupID] = c } if err = rows.Err(); err != nil { return nil, err diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 4a849a46..eccf5cea 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2) s.Require().NoError(err) - count, err := s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(2), count) } @@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { } s.Require().NoError(s.repo.Create(s.ctx, group)) - count, err := s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err) s.Require().Zero(count) } @@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { 
s.Require().NoError(err, "DeleteAccountGroupsByGroupID") s.Require().Equal(int64(1), affected, "expected 1 affected row") - count, err := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(0), count, "expected 0 account groups") } @@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { s.Require().NoError(err) s.Require().Equal(int64(3), affected) - count, _ := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index a4674c1a..4309e997 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,6 +1,8 @@ package repository import ( + "compress/flate" + "compress/gzip" "errors" "fmt" "io" @@ -13,6 +15,8 @@ import ( "sync/atomic" "time" + "github.com/andybalholm/brotli" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" @@ -143,6 +147,9 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i return nil, err } + // 如果上游返回了压缩内容,解压后再交给业务层 + decompressResponseBody(resp) + // 包装响应体,在关闭时自动减少计数并更新时间戳 // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 resp.Body = wrapTrackedBody(resp.Body, func() { @@ -154,26 +161,14 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i } // DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求 -// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹 // -// 参数: -// - req: HTTP 请求对象 -// - proxyURL: 代理地址,空字符串表示直连 -// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择 -// - accountConcurrency: 账户并发限制,用于动态调整连接池大小 -// - enableTLSFingerprint: 是否启用 TLS 指纹伪装 -// -// TLS 指纹说明: -// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹 -// - 指纹模板根据 accountID % len(profiles) 自动选择 -// - 
支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景 -func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { - // 如果未启用 TLS 指纹,直接使用标准请求路径 - if !enableTLSFingerprint { +// profile 为 nil 时不启用 TLS 指纹,行为与 Do 方法相同。 +// profile 非 nil 时使用指定的 Profile 进行 TLS 指纹伪装。 +func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + if profile == nil { return s.Do(req, proxyURL, accountID, accountConcurrency) } - // TLS 指纹已启用,记录调试日志 targetHost := "" if req != nil && req.URL != nil { targetHost = req.URL.Host @@ -182,43 +177,28 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco if proxyURL != "" { proxyInfo = proxyURL } - slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo) + slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo, "profile", profile.Name) if err := s.validateRequestHost(req); err != nil { return nil, err } - // 获取 TLS 指纹 Profile - registry := tlsfingerprint.GlobalRegistry() - profile := registry.GetProfileByAccountID(accountID) - if profile == nil { - // 如果获取不到 profile,回退到普通请求 - slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request") - return s.Do(req, proxyURL, accountID, accountConcurrency) - } - - slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE) - - // 获取或创建带 TLS 指纹的客户端 entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile) if err != nil { slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err) return nil, err } - // 执行请求 resp, err := entry.client.Do(req) if err != nil { - // 请求失败,立即减少计数 atomic.AddInt64(&entry.inFlight, -1) 
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err) return nil, err } - slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode) + decompressResponseBody(resp) - // 包装响应体,在关闭时自动减少计数并更新时间戳 resp.Body = wrapTrackedBody(resp.Body, func() { atomic.AddInt64(&entry.inFlight, -1) atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) @@ -884,3 +864,56 @@ func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser { } return &trackedBody{ReadCloser: body, onClose: onClose} } + +// decompressResponseBody 根据 Content-Encoding 解压响应体。 +// 当请求显式设置了 accept-encoding 时,Go 的 Transport 不会自动解压,需要手动处理。 +// 解压成功后会删除 Content-Encoding 和 Content-Length header(长度已不准确)。 +func decompressResponseBody(resp *http.Response) { + if resp == nil || resp.Body == nil { + return + } + ce := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) + if ce == "" { + return + } + + var reader io.Reader + switch ce { + case "gzip": + gr, err := gzip.NewReader(resp.Body) + if err != nil { + return // 解压失败,保持原样 + } + reader = gr + case "br": + reader = brotli.NewReader(resp.Body) + case "deflate": + reader = flate.NewReader(resp.Body) + default: + return + } + + originalBody := resp.Body + resp.Body = &decompressedBody{reader: reader, closer: originalBody} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") // 解压后长度不确定 + resp.ContentLength = -1 +} + +// decompressedBody 组合解压 reader 和原始 body 的 close。 +type decompressedBody struct { + reader io.Reader + closer io.Closer +} + +func (d *decompressedBody) Read(p []byte) (int, error) { + return d.reader.Read(p) +} + +func (d *decompressedBody) Close() error { + // 如果 reader 本身也是 Closer(如 gzip.Reader),先关闭它 + if rc, ok := d.reader.(io.Closer); ok { + _ = rc.Close() + } + return d.closer.Close() +} diff --git a/backend/internal/repository/internal500_counter_cache.go 
b/backend/internal/repository/internal500_counter_cache.go new file mode 100644 index 00000000..13b0faa8 --- /dev/null +++ b/backend/internal/repository/internal500_counter_cache.go @@ -0,0 +1,55 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + internal500CounterPrefix = "internal500_count:account:" + internal500CounterTTLSeconds = 86400 // 24 小时兜底 +) + +// internal500CounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值 +// 如果 key 不存在,则创建并设置过期时间 +var internal500CounterIncrScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + local count = redis.call('INCR', key) + if count == 1 then + redis.call('EXPIRE', key, ttl) + end + + return count +`) + +type internal500CounterCache struct { + rdb *redis.Client +} + +// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例 +func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache { + return &internal500CounterCache{rdb: rdb} +} + +// IncrementInternal500Count 原子递增计数并返回当前值 +func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) { + key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID) + + result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64() + if err != nil { + return 0, fmt.Errorf("increment internal500 count: %w", err) + } + + return result, nil +} + +// ResetInternal500Count 清零计数器(成功响应时调用) +func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 02ca1a3b..5154b269 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -29,6 +29,11 @@ INSERT INTO 
ops_error_logs ( model, request_path, stream, + inbound_endpoint, + upstream_endpoint, + requested_model, + upstream_model, + request_type, user_agent, error_phase, error_type, @@ -57,7 +62,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43 )` func NewOpsRepository(db *sql.DB) service.OpsRepository { @@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { opsNullString(input.Model), opsNullString(input.RequestPath), input.Stream, + opsNullString(input.InboundEndpoint), + opsNullString(input.UpstreamEndpoint), + opsNullString(input.RequestedModel), + opsNullString(input.UpstreamModel), + opsNullInt16(input.RequestType), opsNullString(input.UserAgent), input.ErrorPhase, input.ErrorType, @@ -231,7 +241,12 @@ SELECT COALESCE(g.name, ''), CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), - e.stream + e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type FROM ops_error_logs e LEFT JOIN accounts a ON e.account_id = a.id LEFT JOIN groups g ON e.group_id = g.id @@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) var resolvedBy sql.NullInt64 var resolvedByName string var resolvedRetryID sql.NullInt64 + var requestType sql.NullInt64 if err := rows.Scan( &item.ID, &item.CreatedAt, @@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) &clientIP, &item.RequestPath, &item.Stream, + &item.InboundEndpoint, + &item.UpstreamEndpoint, + &item.RequestedModel, + &item.UpstreamModel, + &requestType, ); err 
!= nil { return nil, err } @@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) item.GroupID = &v } item.GroupName = groupName + if requestType.Valid { + v := int16(requestType.Int64) + item.RequestType = &v + } out = append(out, &item) } if err := rows.Err(); err != nil { @@ -393,6 +418,11 @@ SELECT CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type, COALESCE(e.user_agent, ''), e.auth_latency_ms, e.routing_latency_ms, @@ -427,6 +457,7 @@ LIMIT 1` var responseLatency sql.NullInt64 var ttft sql.NullInt64 var requestBodyBytes sql.NullInt64 + var requestType sql.NullInt64 err := r.db.QueryRowContext(ctx, q, id).Scan( &out.ID, @@ -464,6 +495,11 @@ LIMIT 1` &clientIP, &out.RequestPath, &out.Stream, + &out.InboundEndpoint, + &out.UpstreamEndpoint, + &out.RequestedModel, + &out.UpstreamModel, + &requestType, &out.UserAgent, &authLatency, &routingLatency, @@ -540,6 +576,10 @@ LIMIT 1` v := int(requestBodyBytes.Int64) out.RequestBodyBytes = &v } + if requestType.Valid { + v := int16(requestType.Int64) + out.RequestType = &v + } // Normalize request_body to empty string when stored as JSON null. 
out.RequestBody = strings.TrimSpace(out.RequestBody) @@ -1479,3 +1519,10 @@ func opsNullInt(v any) any { return sql.NullInt64{} } } + +func opsNullInt16(v *int16) any { + if v == nil { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: int64(*v), Valid: true} +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index b4aeab71..d877abde 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -40,7 +40,7 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 10 * time.Second defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) diff --git a/backend/internal/repository/tls_fingerprint_profile_cache.go b/backend/internal/repository/tls_fingerprint_profile_cache.go new file mode 100644 index 00000000..81ee0434 --- /dev/null +++ b/backend/internal/repository/tls_fingerprint_profile_cache.go @@ -0,0 +1,122 @@ +package repository + +import ( + "context" + "encoding/json" + "log/slog" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + tlsFPProfileCacheKey = "tls_fingerprint_profiles" + tlsFPProfilePubSubKey = "tls_fingerprint_profiles_updated" + tlsFPProfileCacheTTL = 24 * time.Hour +) + +type tlsFingerprintProfileCache struct { + rdb *redis.Client + localCache []*model.TLSFingerprintProfile + localMu sync.RWMutex +} + +// NewTLSFingerprintProfileCache 创建 TLS 指纹模板缓存 +func NewTLSFingerprintProfileCache(rdb *redis.Client) service.TLSFingerprintProfileCache { + return &tlsFingerprintProfileCache{ + rdb: rdb, + } +} + +// Get 从缓存获取模板列表 +func (c *tlsFingerprintProfileCache) Get(ctx context.Context) ([]*model.TLSFingerprintProfile, bool) { + c.localMu.RLock() + if c.localCache != nil { + profiles := c.localCache + 
c.localMu.RUnlock() + return profiles, true + } + c.localMu.RUnlock() + + data, err := c.rdb.Get(ctx, tlsFPProfileCacheKey).Bytes() + if err != nil { + if err != redis.Nil { + slog.Warn("tls_fp_profile_cache_get_failed", "error", err) + } + return nil, false + } + + var profiles []*model.TLSFingerprintProfile + if err := json.Unmarshal(data, &profiles); err != nil { + slog.Warn("tls_fp_profile_cache_unmarshal_failed", "error", err) + return nil, false + } + + c.localMu.Lock() + c.localCache = profiles + c.localMu.Unlock() + + return profiles, true +} + +// Set 设置缓存 +func (c *tlsFingerprintProfileCache) Set(ctx context.Context, profiles []*model.TLSFingerprintProfile) error { + data, err := json.Marshal(profiles) + if err != nil { + return err + } + + if err := c.rdb.Set(ctx, tlsFPProfileCacheKey, data, tlsFPProfileCacheTTL).Err(); err != nil { + return err + } + + c.localMu.Lock() + c.localCache = profiles + c.localMu.Unlock() + + return nil +} + +// Invalidate 使缓存失效 +func (c *tlsFingerprintProfileCache) Invalidate(ctx context.Context) error { + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + return c.rdb.Del(ctx, tlsFPProfileCacheKey).Err() +} + +// NotifyUpdate 通知其他实例刷新缓存 +func (c *tlsFingerprintProfileCache) NotifyUpdate(ctx context.Context) error { + return c.rdb.Publish(ctx, tlsFPProfilePubSubKey, "refresh").Err() +} + +// SubscribeUpdates 订阅缓存更新通知 +func (c *tlsFingerprintProfileCache) SubscribeUpdates(ctx context.Context, handler func()) { + go func() { + sub := c.rdb.Subscribe(ctx, tlsFPProfilePubSubKey) + defer func() { _ = sub.Close() }() + + ch := sub.Channel() + for { + select { + case <-ctx.Done(): + slog.Debug("tls_fp_profile_cache_subscriber_stopped", "reason", "context_done") + return + case msg := <-ch: + if msg == nil { + slog.Warn("tls_fp_profile_cache_subscriber_stopped", "reason", "channel_closed") + return + } + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + handler() + } + } + }() +} diff --git 
a/backend/internal/repository/tls_fingerprint_profile_repo.go b/backend/internal/repository/tls_fingerprint_profile_repo.go new file mode 100644 index 00000000..40bebdc3 --- /dev/null +++ b/backend/internal/repository/tls_fingerprint_profile_repo.go @@ -0,0 +1,213 @@ +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type tlsFingerprintProfileRepository struct { + client *ent.Client +} + +// NewTLSFingerprintProfileRepository 创建 TLS 指纹模板仓库 +func NewTLSFingerprintProfileRepository(client *ent.Client) service.TLSFingerprintProfileRepository { + return &tlsFingerprintProfileRepository{client: client} +} + +// List 获取所有模板 +func (r *tlsFingerprintProfileRepository) List(ctx context.Context) ([]*model.TLSFingerprintProfile, error) { + profiles, err := r.client.TLSFingerprintProfile.Query(). + Order(ent.Asc(tlsfingerprintprofile.FieldName)). + All(ctx) + if err != nil { + return nil, err + } + + result := make([]*model.TLSFingerprintProfile, len(profiles)) + for i, p := range profiles { + result[i] = r.toModel(p) + } + return result, nil +} + +// GetByID 根据 ID 获取模板 +func (r *tlsFingerprintProfileRepository) GetByID(ctx context.Context, id int64) (*model.TLSFingerprintProfile, error) { + p, err := r.client.TLSFingerprintProfile.Get(ctx, id) + if err != nil { + if ent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + return r.toModel(p), nil +} + +// Create 创建模板 +func (r *tlsFingerprintProfileRepository) Create(ctx context.Context, p *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) { + builder := r.client.TLSFingerprintProfile.Create(). + SetName(p.Name). 
+ SetEnableGrease(p.EnableGREASE) + + if p.Description != nil { + builder.SetDescription(*p.Description) + } + if len(p.CipherSuites) > 0 { + builder.SetCipherSuites(p.CipherSuites) + } + if len(p.Curves) > 0 { + builder.SetCurves(p.Curves) + } + if len(p.PointFormats) > 0 { + builder.SetPointFormats(p.PointFormats) + } + if len(p.SignatureAlgorithms) > 0 { + builder.SetSignatureAlgorithms(p.SignatureAlgorithms) + } + if len(p.ALPNProtocols) > 0 { + builder.SetAlpnProtocols(p.ALPNProtocols) + } + if len(p.SupportedVersions) > 0 { + builder.SetSupportedVersions(p.SupportedVersions) + } + if len(p.KeyShareGroups) > 0 { + builder.SetKeyShareGroups(p.KeyShareGroups) + } + if len(p.PSKModes) > 0 { + builder.SetPskModes(p.PSKModes) + } + if len(p.Extensions) > 0 { + builder.SetExtensions(p.Extensions) + } + + created, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(created), nil +} + +// Update 更新模板 +func (r *tlsFingerprintProfileRepository) Update(ctx context.Context, p *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) { + builder := r.client.TLSFingerprintProfile.UpdateOneID(p.ID). + SetName(p.Name). 
+ SetEnableGrease(p.EnableGREASE) + + if p.Description != nil { + builder.SetDescription(*p.Description) + } else { + builder.ClearDescription() + } + + if len(p.CipherSuites) > 0 { + builder.SetCipherSuites(p.CipherSuites) + } else { + builder.ClearCipherSuites() + } + if len(p.Curves) > 0 { + builder.SetCurves(p.Curves) + } else { + builder.ClearCurves() + } + if len(p.PointFormats) > 0 { + builder.SetPointFormats(p.PointFormats) + } else { + builder.ClearPointFormats() + } + if len(p.SignatureAlgorithms) > 0 { + builder.SetSignatureAlgorithms(p.SignatureAlgorithms) + } else { + builder.ClearSignatureAlgorithms() + } + if len(p.ALPNProtocols) > 0 { + builder.SetAlpnProtocols(p.ALPNProtocols) + } else { + builder.ClearAlpnProtocols() + } + if len(p.SupportedVersions) > 0 { + builder.SetSupportedVersions(p.SupportedVersions) + } else { + builder.ClearSupportedVersions() + } + if len(p.KeyShareGroups) > 0 { + builder.SetKeyShareGroups(p.KeyShareGroups) + } else { + builder.ClearKeyShareGroups() + } + if len(p.PSKModes) > 0 { + builder.SetPskModes(p.PSKModes) + } else { + builder.ClearPskModes() + } + if len(p.Extensions) > 0 { + builder.SetExtensions(p.Extensions) + } else { + builder.ClearExtensions() + } + + updated, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(updated), nil +} + +// Delete 删除模板 +func (r *tlsFingerprintProfileRepository) Delete(ctx context.Context, id int64) error { + return r.client.TLSFingerprintProfile.DeleteOneID(id).Exec(ctx) +} + +// toModel 将 Ent 实体转换为服务模型 +func (r *tlsFingerprintProfileRepository) toModel(e *ent.TLSFingerprintProfile) *model.TLSFingerprintProfile { + p := &model.TLSFingerprintProfile{ + ID: e.ID, + Name: e.Name, + Description: e.Description, + EnableGREASE: e.EnableGrease, + CipherSuites: e.CipherSuites, + Curves: e.Curves, + PointFormats: e.PointFormats, + SignatureAlgorithms: e.SignatureAlgorithms, + ALPNProtocols: e.AlpnProtocols, + SupportedVersions: e.SupportedVersions, + 
KeyShareGroups: e.KeyShareGroups, + PSKModes: e.PskModes, + Extensions: e.Extensions, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + } + + // 确保切片不为 nil + if p.CipherSuites == nil { + p.CipherSuites = []uint16{} + } + if p.Curves == nil { + p.Curves = []uint16{} + } + if p.PointFormats == nil { + p.PointFormats = []uint16{} + } + if p.SignatureAlgorithms == nil { + p.SignatureAlgorithms = []uint16{} + } + if p.ALPNProtocols == nil { + p.ALPNProtocols = []string{} + } + if p.SupportedVersions == nil { + p.SupportedVersions = []uint16{} + } + if p.KeyShareGroups == nil { + p.KeyShareGroups = []uint16{} + } + if p.PSKModes == nil { + p.PSKModes = []uint16{} + } + if p.Extensions == nil { + p.Extensions = []uint16{} + } + + return p +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dc70812d..e4da825b 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,49 +28,64 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, 
cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +// usageLogInsertArgTypes must stay in the same order as: +// 1. prepareUsageLogInsert().args +// 2. every INSERT/CTE VALUES column list in this file +// 3. execUsageLogInsertNoResult placeholder positions +// 4. scanUsageLog selected column order (via usageLogSelectColumns) +// +// When adding a usage_logs column, update all of those call sites together. var usageLogInsertArgTypes = [...]string{ - "bigint", - "bigint", - "bigint", - "text", - "text", - "bigint", - "bigint", - "integer", - "integer", - "integer", - "integer", - "integer", - "integer", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "smallint", - "smallint", - "boolean", - "boolean", - "integer", - "integer", - "text", - "text", - "integer", - "text", - "text", - "text", - "text", - "text", - "text", - "boolean", - "timestamptz", + "bigint", // user_id + "bigint", // api_key_id + "bigint", // account_id + "text", // request_id + "text", // model + "text", // requested_model + "text", // upstream_model + "bigint", // group_id + "bigint", // subscription_id + "integer", // input_tokens + "integer", // output_tokens + "integer", // cache_creation_tokens + "integer", // cache_read_tokens + "integer", // cache_creation_5m_tokens + "integer", // cache_creation_1h_tokens + "numeric", // input_cost + "numeric", // output_cost + "numeric", // cache_creation_cost + "numeric", // cache_read_cost + "numeric", // total_cost + "numeric", // actual_cost + "numeric", // rate_multiplier + "numeric", // account_rate_multiplier + "smallint", // billing_type + "smallint", // request_type + "boolean", // stream + "boolean", // 
openai_ws_mode + "integer", // duration_ms + "integer", // first_token_ms + "text", // user_agent + "text", // ip_address + "integer", // image_count + "text", // image_size + "text", // media_type + "text", // service_tier + "text", // reasoning_effort + "text", // inbound_endpoint + "text", // upstream_endpoint + "boolean", // cache_ttl_overridden + "timestamptz", // created_at } +const rawUsageLogModelColumn = "model" + +// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters. +// Historical rows may contain upstream/billing model values, while newer rows store requested_model. +// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead. + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -87,6 +102,30 @@ func safeDateFormat(granularity string) string { return "YYYY-MM-DD" } +// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. +func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) { + if strings.TrimSpace(model) == "" { + return conditions, args + } + conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1)) + args = append(args, model) + return conditions, args +} + +// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. 
+func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) { + if strings.TrimSpace(model) == "" { + return query, args + } + query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1) + args = append(args, model) + return query, args +} + type usageLogRepository struct { client *dbent.Client sql sqlExecutor @@ -277,6 +316,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -311,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -707,6 +748,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -742,7 +785,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*38) + args := make([]any, 0, len(keys)*39) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -776,6 +819,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -816,6 +861,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, 
model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -896,6 +943,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -931,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*38) + args := make([]any, 0, len(preparedList)*40) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -962,6 +1011,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1002,6 +1053,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1050,6 +1103,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + requested_model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1084,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) 
@@ -1121,6 +1176,11 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + requestedModel := strings.TrimSpace(log.RequestedModel) + if requestedModel == "" { + requestedModel = strings.TrimSpace(log.Model) + } + upstreamModel := nullString(log.UpstreamModel) var requestIDArg any if requestID != "" { @@ -1138,6 +1198,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + nullString(&requestedModel), + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -1691,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco // GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 // 性能优化:数据库层聚合计算,避免应用层循环统计 func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := ` + query := fmt.Sprintf(` SELECT COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -1701,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms FROM usage_logs - WHERE model = $1 AND created_at >= $2 AND created_at < $3 - ` + WHERE %s = $1 AND created_at >= $2 AND created_at < $3 + `, rawUsageLogModelColumn) var stats usagestats.UsageStats if err := scanSingleRow( @@ -1826,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 
AND created_at < $3 ORDER BY id DESC LIMIT 10000" + query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -2521,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -2757,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -2864,15 +2920,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return 
r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. +// source: requested | upstream | mapping. +func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source) +} + +func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` SELECT - model, + %s as model, COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, @@ -2883,7 +2950,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, actualCostExpr) + `, modelExpr, actualCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -2907,7 +2974,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) } - 
query += " GROUP BY model ORDER BY total_tokens DESC" + query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr) rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { @@ -3000,6 +3067,133 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start return results, nil } +// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension. +func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) { + query := ` + SELECT + COALESCE(ul.user_id, 0) as user_id, + COALESCE(u.email, '') as email, + COUNT(*) as requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(ul.total_cost), 0) as cost, + COALESCE(SUM(ul.actual_cost), 0) as actual_cost + FROM usage_logs ul + LEFT JOIN users u ON u.id = ul.user_id + WHERE ul.created_at >= $1 AND ul.created_at < $2 + ` + args := []any{startTime, endTime} + + if dim.GroupID > 0 { + query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) + args = append(args, dim.GroupID) + } + if dim.Model != "" { + query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) + args = append(args, dim.Model) + } + if dim.Endpoint != "" { + col := resolveEndpointColumn(dim.EndpointType) + query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) + args = append(args, dim.Endpoint) + } + + query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + rows, err := r.sql.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]usagestats.UserBreakdownItem, 0) + for rows.Next() { + var row usagestats.UserBreakdownItem + if err := rows.Scan( + &row.UserID, + &row.Email, + &row.Requests, + &row.TotalTokens, + &row.Cost, + &row.ActualCost, + ); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group. +// todayStart is the start-of-day in the caller's timezone (UTC-based). +// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation. +// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s) +// or a materialized view / pre-aggregation table for cumulative costs. +func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + query := ` + SELECT + g.id AS group_id, + COALESCE(SUM(ul.actual_cost), 0) AS total_cost, + COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost + FROM groups g + LEFT JOIN usage_logs ul ON ul.group_id = g.id + GROUP BY g.id + ` + + rows, err := r.sql.QueryContext(ctx, query, todayStart) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var results []usagestats.GroupUsageSummary + for rows.Next() { + var row usagestats.GroupUsageSummary + if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// resolveModelDimensionExpression maps model source type to a safe SQL expression. 
+func resolveModelDimensionExpression(modelType string) string { + requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)" + switch usagestats.NormalizeModelSource(modelType) { + case usagestats.ModelSourceUpstream: + return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr) + case usagestats.ModelSourceMapping: + return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr) + default: + return requestedExpr + } +} + +// resolveEndpointColumn maps endpoint type to the corresponding DB column name. +func resolveEndpointColumn(endpointType string) string { + switch endpointType { + case "upstream": + return "ul.upstream_endpoint" + case "path": + return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint" + default: + return "ul.inbound_endpoint" + } +} + // GetGlobalStats gets usage statistics for all users within a time range func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { query := ` @@ -3056,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -3188,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", 
len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3262,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3740,6 +3925,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + requestedModel sql.NullString + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -3782,6 +3969,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &requestedModel, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -3825,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e APIKeyID: apiKeyID, AccountID: accountID, Model: model, + RequestedModel: coalesceTrimmedString(requestedModel, model), InputTokens: inputTokens, OutputTokens: outputTokens, CacheCreationTokens: cacheCreationTokens, @@ -3894,6 +4084,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamEndpoint.Valid { log.UpstreamEndpoint = &upstreamEndpoint.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } @@ -4028,6 +4221,13 @@ func nullString(v *string) sql.NullString { return 
sql.NullString{String: *v, Valid: true} } +func coalesceTrimmedString(v sql.NullString, fallback string) string { + if v.Valid && strings.TrimSpace(v.String) != "" { + return v.String + } + return fallback +} + func setToSlice(set map[int64]struct{}) []int64 { out := make([]int64, 0, len(set)) for id := range set { diff --git a/backend/internal/repository/usage_log_repo_breakdown_test.go b/backend/internal/repository/usage_log_repo_breakdown_test.go new file mode 100644 index 00000000..da62e8dd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_breakdown_test.go @@ -0,0 +1,50 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/stretchr/testify/require" +) + +func TestResolveEndpointColumn(t *testing.T) { + tests := []struct { + endpointType string + want string + }{ + {"inbound", "ul.inbound_endpoint"}, + {"upstream", "ul.upstream_endpoint"}, + {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, + {"", "ul.inbound_endpoint"}, // default + {"unknown", "ul.inbound_endpoint"}, // fallback + } + + for _, tc := range tests { + t.Run(tc.endpointType, func(t *testing.T) { + got := resolveEndpointColumn(tc.endpointType) + require.Equal(t, tc.want, got) + }) + } +} + +func TestResolveModelDimensionExpression(t *testing.T) { + tests := []struct { + modelType string + want string + }{ + {usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, + {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"}, + {usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"}, + {"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, + {"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, + } + + for _, tc := range tests { + 
t.Run(tc.modelType, func(t *testing.T) { + got := resolveModelDimensionExpression(tc.modelType) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 27ae4571..ebc8929a 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" @@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) log := &service.UsageLog{ - UserID: 1, - APIKeyID: 2, - AccountID: 3, - RequestID: "req-1", - Model: "gpt-5", - InputTokens: 10, - OutputTokens: 20, - TotalCost: 1, - ActualCost: 1, - BillingType: service.BillingTypeBalance, - RequestType: service.RequestTypeWSV2, - Stream: false, - OpenAIWSMode: false, - CreatedAt: createdAt, + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-1", + Model: "gpt-5", + RequestedModel: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, } mock.ExpectQuery("INSERT INTO usage_logs"). 
@@ -44,6 +46,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + log.RequestedModel, + sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id log.InputTokens, @@ -98,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) serviceTier := "priority" log := &service.UsageLog{ - UserID: 1, - APIKeyID: 2, - AccountID: 3, - RequestID: "req-service-tier", - Model: "gpt-5.4", - ServiceTier: &serviceTier, - CreatedAt: createdAt, + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-service-tier", + Model: "gpt-5.4", + RequestedModel: "gpt-5.4", + ServiceTier: &serviceTier, + CreatedAt: createdAt, } mock.ExpectQuery("INSERT INTO usage_logs"). @@ -114,6 +119,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.AccountID, log.RequestID, log.Model, + log.RequestedModel, + sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), log.InputTokens, @@ -156,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { require.NoError(t, mock.ExpectationsWereMet()) } +func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) { + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-best-effort-query", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC), + }) + + query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared}) + + require.Contains(t, query, "INSERT INTO usage_logs (") + require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,") + require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,") + require.Len(t, args, len(prepared.args)) + require.Equal(t, prepared.args[5], 
args[5]) +} + +func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) { + db, mock := newSQLMock(t) + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-best-effort-exec", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC), + }) + + mock.ExpectExec("INSERT INTO usage_logs"). + WithArgs(anySliceToDriverValues(prepared.args)...). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := execUsageLogInsertNoResult(context.Background(), db, prepared) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) { + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-arg-count", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC), + }) + + require.Len(t, prepared.args, len(usageLogInsertArgTypes)) +} + +func TestCoalesceTrimmedString(t *testing.T) { + require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback")) + require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback")) + require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback")) +} + +func anySliceToDriverValues(values []any) []driver.Value { + out := make([]driver.Value, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out +} + func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { db, mock := newSQLMock(t) repo := &usageLogRepository{sql: db} @@ -352,7 +428,9 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(20), // api_key_id int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, - "gpt-5", // model + "gpt-5", // model + sql.NullString{Valid: true, String: "gpt-5"}, // 
requested_model + sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id 1, // input_tokens @@ -404,6 +482,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: true, String: "req-2"}, "gpt-5", + sql.NullString{Valid: true, String: "gpt-5"}, + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, @@ -445,6 +525,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{Valid: true, String: "gpt-5.4"}, + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index b56aaaf9..575754e0 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -11,6 +11,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -200,6 +201,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. ) } + if filters.GroupName != "" { + q = q.Where(dbuser.HasAllowedGroupsWith( + dbgroup.NameContainsFold(filters.GroupName), + )) + } + // If attribute filters are specified, we need to filter by user IDs first var allowedUserIDs []int64 if len(filters.Attributes) > 0 { @@ -453,6 +460,15 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group return int64(affected), nil } +// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限 +func (r *userRepository) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + client := clientFromContext(ctx, r.client) + _, err := client.UserAllowedGroup.Delete(). 
+ Where(userallowedgroup.UserIDEQ(userID), userallowedgroup.GroupIDEQ(groupID)). + Exec(ctx) + return err +} + func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) { m, err := r.client.User.Query(). Where( diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 5a649846..e3f64a5f 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -5,6 +5,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil } -func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { client := clientFromContext(ctx, r.client) q := client.UserSubscription.Query() if userID != nil { @@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination if groupID != nil { q = q.Where(usersubscription.GroupIDEQ(*groupID)) } + if platform != "" { + q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform))) + } // Status filtering with real-time expiration check now := time.Now() diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go 
b/backend/internal/repository/user_subscription_repo_integration_test.go index 60a5a378..a74860e3 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { group := s.mustCreateGroup("g-list") s.mustCreateSubscription(user.ID, group.ID, nil) - subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") + subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "") s.Require().NoError(err, "List") s.Require().Len(subs, 1) s.Require().Equal(int64(1), page.Total) @@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(user1.ID, subs[0].UserID) @@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(g1.ID, subs[0].GroupID) @@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) }) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, 
nil, service.SubscriptionStatusExpired, "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 138bf59e..49d47bf6 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -73,6 +73,7 @@ var ProviderSet = wire.NewSet( NewUserAttributeValueRepository, NewUserGroupRateRepository, NewErrorPassthroughRepository, + NewTLSFingerprintProfileRepository, // Cache implementations NewGatewayCache, @@ -80,6 +81,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, NewTimeoutCounterCache, + NewInternal500CounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, @@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet( NewTotpCache, NewRefreshTokenCache, NewErrorPassthroughCache, + NewTLSFingerprintProfileCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 6056d51f..7059cb76 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -214,6 +214,8 @@ func TestAPIContracts(t *testing.T) { "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "allow_messages_dispatch": false, + "require_oauth_only": false, + "require_privacy_set": false, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -537,9 +539,13 @@ func TestAPIContracts(t *testing.T) { "purchase_subscription_enabled": false, "purchase_subscription_url": "", "min_claude_code_version": "", + "max_claude_code_version": "", "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, - "custom_menu_items": [] + "enable_fingerprint_unification": true, + 
"enable_metadata_passthrough": false, + "custom_menu_items": [], + "custom_endpoints": [] } }`, }, @@ -807,6 +813,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID return 0, errors.New("not implemented") } +func (r *stubUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { return errors.New("not implemented") } @@ -924,8 +934,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error return false, errors.New("not implemented") } -func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, errors.New("not implemented") +func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, errors.New("not implemented") } func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -984,7 +994,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -1289,7 +1299,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, 
*pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { @@ -1509,6 +1519,22 @@ func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int6 return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + var updated int64 + for id, key := range r.byID { + if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID { + continue + } + clone := *key + gid := newGroupID + clone.GroupID = &gid + r.byID[id] = &clone + r.byKey[clone.Key] = &clone + updated++ + } + return updated, nil +} + func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } @@ -1637,6 +1663,10 @@ func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, 
errors.New("not implemented") } @@ -1782,6 +1812,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, errors.New("not implemented") +} type stubSettingRepo struct { all map[string]string diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 138663c4..aafe4a58 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID panic("unexpected RemoveGroupFromAllowedGroups call") } +func (s *stubUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { panic("unexpected AddGroupToAllowedGroups call") } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 49db5f19..f8e50fcd 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -104,6 +104,9 @@ func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) err func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { return &service.APIKeyRateLimitData{}, nil } +func (f fakeAPIKeyRepo) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID 
int64) (int64, error) { + return 0, errors.New("not implemented") +} func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") @@ -135,7 +138,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 22befa2a..4a4ab0f9 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -565,6 +565,10 @@ func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int6 return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } @@ -646,7 +650,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx 
context.Context, groupID in return nil, nil, errors.New("not implemented") } -func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 85bfa6a6..e04dae85 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -79,6 +79,9 @@ func RegisterAdminRoutes( // 错误透传规则管理 registerErrorPassthroughRoutes(admin, h) + // TLS 指纹模板管理 + registerTLSFingerprintProfileRoutes(admin, h) + // API Key 管理 registerAdminAPIKeyRoutes(admin, h) @@ -198,6 +201,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) + dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown) dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) } } @@ -214,6 +218,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) + users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -226,6 +231,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h 
*handler.Handlers) { { groups.GET("", h.Admin.Group.List) groups.GET("/all", h.Admin.Group.GetAll) + groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary) + groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) @@ -253,6 +260,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/test", h.Admin.Account.Test) accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState) accounts.POST("/:id/refresh", h.Admin.Account.Refresh) + accounts.POST("/:id/set-privacy", h.Admin.Account.SetPrivacy) accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier) accounts.GET("/:id/stats", h.Admin.Account.GetStats) accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) @@ -399,6 +407,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey) + // 529过载冷却配置 + adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings) + adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings) // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) @@ -545,3 +556,14 @@ func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete) } } + +func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + profiles := admin.Group("/tls-fingerprint-profiles") + { + profiles.GET("", h.Admin.TLSFingerprintProfile.List) + profiles.GET("/:id", h.Admin.TLSFingerprintProfile.GetByID) + 
profiles.POST("", h.Admin.TLSFingerprintProfile.Create) + profiles.PUT("/:id", h.Admin.TLSFingerprintProfile.Update) + profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index fe820830..072cfdee 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -69,12 +69,30 @@ func RegisterGatewayRoutes( }) gateway.GET("/models", h.Gateway.Models) gateway.GET("/usage", h.Gateway.Usage) - // OpenAI Responses API - gateway.POST("/responses", h.OpenAIGateway.Responses) - gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) + // OpenAI Responses API: auto-route based on group platform + gateway.POST("/responses", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) + gateway.POST("/responses/*subpath", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API - gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) + // OpenAI Chat Completions API: auto-route based on group platform + gateway.POST("/chat/completions", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -92,12 +110,25 @@ func RegisterGatewayRoutes( gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } - // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) - r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, 
gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + // OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform + responsesHandler := func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + } + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API(不带v1前缀的别名) - r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) + // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index b6408f5f..512195e3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -141,6 +141,21 @@ func (a *Account) IsOAuth() bool { return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken } +// IsPrivacySet 检查账号的 privacy 是否已成功设置。 +// OpenAI: privacy_mode == "training_off" +// Antigravity: privacy_mode == "privacy_set" +// 其他平台: 
无 privacy 概念,始终返回 true +func (a *Account) IsPrivacySet() bool { + switch a.Platform { + case PlatformOpenAI: + return a.getExtraString("privacy_mode") == PrivacyModeTrainingOff + case PlatformAntigravity: + return a.getExtraString("privacy_mode") == AntigravityPrivacySet + default: + return true + } +} + func (a *Account) IsGemini() bool { return a.Platform == PlatformGemini } @@ -500,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st } } +func normalizeRequestedModelForLookup(platform, requestedModel string) string { + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return "" + } + if platform != PlatformGemini && platform != PlatformAntigravity { + return trimmed + } + if trimmed == "gemini-3.1-pro-preview-customtools" { + return "gemini-3.1-pro-preview" + } + return trimmed +} + +func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool { + if requestedModel == "" { + return false + } + if _, exists := mapping[requestedModel]; exists { + return true + } + for pattern := range mapping { + if matchWildcard(pattern, requestedModel) { + return true + } + } + return false +} + +func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) { + if requestedModel == "" { + return "", false + } + if mappedModel, exists := mapping[requestedModel]; exists { + return mappedModel, true + } + return matchWildcardMappingResult(mapping, requestedModel) +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -507,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool { if len(mapping) == 0 { return true // 无映射 = 允许所有 } - // 精确匹配 - if _, exists := mapping[requestedModel]; exists { + if mappingSupportsRequestedModel(mapping, requestedModel) { return true } - // 通配符匹配 - for pattern := range mapping { - if 
matchWildcard(pattern, requestedModel) { - return true - } - } - return false + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized) } // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) @@ -534,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, if len(mapping) == 0 { return requestedModel, false } - // 精确匹配优先 - if mappedModel, exists := mapping[requestedModel]; exists { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched { return mappedModel, true } - // 通配符匹配(最长优先) - return matchWildcardMappingResult(mapping, requestedModel) + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + if normalized != requestedModel { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched { + return mappedModel, true + } + } + return requestedModel, false } func (a *Account) GetBaseURL() string { @@ -1165,6 +1217,31 @@ func (a *Account) IsTLSFingerprintEnabled() bool { return false } +// GetTLSFingerprintProfileID 获取账号绑定的 TLS 指纹模板 ID +// 返回 0 表示未绑定(使用内置默认 profile) +func (a *Account) GetTLSFingerprintProfileID() int64 { + if a.Extra == nil { + return 0 + } + v, ok := a.Extra["tls_fingerprint_profile_id"] + if !ok { + return 0 + } + switch id := v.(type) { + case float64: + return int64(id) + case int64: + return id + case int: + return int64(id) + case json.Number: + if i, err := id.Int64(); err == nil { + return i + } + } + return 0 +} + // GetUserMsgQueueMode 获取用户消息队列模式 // "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置) func (a *Account) GetUserMsgQueueMode() string { @@ -1204,6 +1281,28 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCustomBaseURLEnabled 检查是否启用自定义 base URL 中继转发 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +func (a *Account) IsCustomBaseURLEnabled() bool { + if 
!a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["custom_base_url_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCustomBaseURL 返回自定义中继服务的 base URL +func (a *Account) GetCustomBaseURL() string { + return a.GetExtraString("custom_base_url") +} + // IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) @@ -1543,6 +1642,24 @@ func isPeriodExpired(periodStart time.Time, dur time.Duration) bool { return time.Since(periodStart) >= dur } +// IsDailyQuotaPeriodExpired 检查日配额周期是否已过期(用于显示层判断是否需要将 used 归零) +func (a *Account) IsDailyQuotaPeriodExpired() bool { + start := a.getExtraTime("quota_daily_start") + if a.GetQuotaDailyResetMode() == "fixed" { + return a.isFixedDailyPeriodExpired(start) + } + return isPeriodExpired(start, 24*time.Hour) +} + +// IsWeeklyQuotaPeriodExpired 检查周配额周期是否已过期(用于显示层判断是否需要将 used 归零) +func (a *Account) IsWeeklyQuotaPeriodExpired() bool { + start := a.getExtraTime("quota_weekly_start") + if a.GetQuotaWeeklyResetMode() == "fixed" { + return a.isFixedWeeklyPeriodExpired(start) + } + return isPeriodExpired(start, 7*24*time.Hour) +} + // IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true) func (a *Account) IsQuotaExceeded() bool { // 总额度 @@ -1662,22 +1779,47 @@ func (a *Account) GetRPMStrategy() string { } // GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 -// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +// Cache-driven: buffer = concurrency + maxSessions(覆盖幽灵窗口 + 稳态会话需求) +// floor = baseRPM / 5(向后兼容 maxSessions=0 且 concurrency=0 场景) func (a *Account) GetRPMStickyBuffer() int { if a.Extra == nil { return 0 } + + // 手动 override 最高优先级 if v, ok := a.Extra["rpm_sticky_buffer"]; ok { val := parseExtraInt(v) if val > 0 { return val } } + base := a.GetBaseRPM() - buffer := base / 5 - if buffer < 1 && base > 0 { - buffer = 1 + if base <= 0 { + return 0 } + + // Cache-driven 
buffer = concurrency + maxSessions + conc := a.Concurrency + if conc < 0 { + conc = 0 + } + sess := a.GetMaxSessions() + if sess < 0 { + sess = 0 + } + + buffer := conc + sess + + // floor: 向后兼容 + floor := base / 5 + if floor < 1 { + floor = 1 + } + if buffer < floor { + buffer = floor + } + return buffer } diff --git a/backend/internal/service/account_credentials_persistence.go b/backend/internal/service/account_credentials_persistence.go new file mode 100644 index 00000000..916df536 --- /dev/null +++ b/backend/internal/service/account_credentials_persistence.go @@ -0,0 +1,30 @@ +package service + +import "context" + +type accountCredentialsUpdater interface { + UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error +} + +func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error { + if repo == nil || account == nil { + return nil + } + + account.Credentials = cloneCredentials(credentials) + if updater, ok := any(repo).(accountCredentialsUpdater); ok { + return updater.UpdateCredentials(ctx, account.ID, account.Credentials) + } + return repo.Update(ctx, account) +} + +func cloneCredentials(in map[string]any) map[string]any { + if in == nil { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go index 9d91f3e0..40298263 100644 --- a/backend/internal/service/account_rpm_test.go +++ b/backend/internal/service/account_rpm_test.go @@ -90,28 +90,47 @@ func TestCheckRPMSchedulability(t *testing.T) { func TestGetRPMStickyBuffer(t *testing.T) { tests := []struct { - name string - extra map[string]any - expected int + name string + concurrency int + extra map[string]any + expected int }{ - {"nil extra", nil, 0}, - {"no keys", map[string]any{}, 0}, - {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, - 
{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, - {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, - {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, - {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, - {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, - {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, - {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, - {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, - {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, - {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, - {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + // 基础退化 + {"nil extra", 0, nil, 0}, + {"no keys", 0, map[string]any{}, 0}, + {"base_rpm=0", 0, map[string]any{"base_rpm": 0}, 0}, + + // 新公式: concurrency + maxSessions, floor = base/5 + {"conc=3 sess=10 → 13", 3, map[string]any{"base_rpm": 15, "max_sessions": 10}, 13}, + {"conc=2 sess=5 → 7", 2, map[string]any{"base_rpm": 10, "max_sessions": 5}, 7}, + {"conc=3 sess=15 → 18", 3, map[string]any{"base_rpm": 30, "max_sessions": 15}, 18}, + + // floor 生效 (conc+sess < base/5) + {"conc=0 sess=0 base=15 → floor 3", 0, map[string]any{"base_rpm": 15}, 3}, + {"conc=0 sess=0 base=10 → floor 2", 0, map[string]any{"base_rpm": 10}, 2}, + {"conc=0 sess=0 base=1 → floor 1", 0, map[string]any{"base_rpm": 1}, 1}, + {"conc=0 sess=0 base=4 → floor 1", 0, map[string]any{"base_rpm": 4}, 1}, + {"conc=1 sess=0 base=15 → floor 3", 1, map[string]any{"base_rpm": 15}, 3}, + + // 手动 override + {"custom buffer=5", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5, "max_sessions": 10}, 5}, + {"custom buffer=0 fallback", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0, "max_sessions": 10}, 13}, + {"custom buffer negative fallback", 3, 
map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1, "max_sessions": 10}, 13}, + {"custom buffer with float", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + + // 负值 clamp + {"negative concurrency clamped", -5, map[string]any{"base_rpm": 15, "max_sessions": 10}, 10}, + {"negative maxSessions clamped", 3, map[string]any{"base_rpm": 15, "max_sessions": -5}, 3}, + + // 高并发低会话 + {"conc=10 sess=5 → 15", 10, map[string]any{"base_rpm": 10, "max_sessions": 5}, 15}, + + // json.Number + {"json.Number base_rpm", 3, map[string]any{"base_rpm": json.Number("10"), "max_sessions": json.Number("5")}, 8}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &Account{Extra: tt.extra} + a := &Account{Concurrency: tt.concurrency, Extra: tt.extra} if got := a.GetRPMStickyBuffer(); got != tt.expected { t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected) } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a06d8048..328790a8 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -14,6 +14,9 @@ var ( ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil") ) +const AccountListGroupUngrouped int64 = -1 +const AccountPrivacyModeUnsetFilter = "__unset__" + type AccountRepository interface { Create(ctx context.Context, account *Account) error GetByID(ctx context.Context, id int64) (*Account, error) @@ -35,7 +38,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, 
status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) @@ -171,6 +174,19 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( return nil, fmt.Errorf("create account: %w", err) } + // require_oauth_only 检查:apikey 类型账号不可加入限制分组 + if account.Type == AccountTypeAPIKey && len(req.GroupIDs) > 0 { + for _, gid := range req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if len(req.GroupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil { @@ -274,6 +290,19 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount return nil, fmt.Errorf("update account: %w", err) } + // require_oauth_only 检查 + if account.Type == AccountTypeAPIKey && req.GroupIDs != nil { + for _, gid := range *req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if req.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index c96b436f..81169a02 100644 --- a/backend/internal/service/account_service_delete_test.go 
+++ b/backend/internal/service/account_service_delete_test.go @@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 482d22b1..fec98e12 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -23,6 +23,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" @@ -69,6 +70,7 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + tlsFPProfileService *TLSFingerprintProfileService soraTestGuardMu sync.Mutex soraTestLastRun map[int64]time.Time soraTestCooldown time.Duration @@ -83,6 +85,7 @@ func NewAccountTestService( antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, cfg *config.Config, + tlsFPProfileService *TLSFingerprintProfileService, ) *AccountTestService { return &AccountTestService{ accountRepo: accountRepo, @@ -90,6 +93,7 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: 
httpUpstream, cfg: cfg, + tlsFPProfileService: tlsFPProfileService, soraTestLastRun: make(map[int64]time.Time), soraTestCooldown: defaultSoraTestCooldown, } @@ -113,15 +117,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) return normalized, nil } -// generateSessionString generates a Claude Code style session string +// generateSessionString generates a Claude Code style session string. +// The output format is determined by the UA version in claude.DefaultHeaders, +// ensuring consistency between the user_id format and the UA sent to upstream. func generateSessionString() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { return "", err } - hex64 := hex.EncodeToString(bytes) + hex64 := hex.EncodeToString(b) sessionUUID := uuid.New().String() - return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil + uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"]) + return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil } // createTestPayload creates a Claude Code style test request payload @@ -297,7 +304,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -305,7 +312,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + errMsg := fmt.Sprintf("API returned %d: %s", 
resp.StatusCode, string(body)) + + // 403 表示账号被上游封禁,标记为 error 状态 + if resp.StatusCode == http.StatusForbidden { + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } + + return s.sendErrorAndEnd(c, errMsg) } // Process SSE stream @@ -380,7 +394,7 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, nil) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -510,7 +524,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -600,7 +614,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -871,9 +885,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint() + soraTLSProfile := s.resolveSoraTLSProfile() - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, 
account.Concurrency, enableSoraTLSFingerprint) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile) if err != nil { recorder.addStep("me", "failed", 0, "network_error", err.Error()) s.emitSoraProbeSummary(c, recorder) @@ -938,7 +952,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * subReq.Header.Set("Origin", "https://sora.chatgpt.com") subReq.Header.Set("Referer", "https://sora.chatgpt.com/") - subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile) if subErr != nil { recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) @@ -967,7 +981,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * } // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 - s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder) + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder) s.emitSoraProbeSummary(c, recorder) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) @@ -980,7 +994,7 @@ func (s *AccountTestService) testSora2Capabilities( account *Account, authToken string, proxyURL string, - enableTLSFingerprint bool, + tlsProfile *tlsfingerprint.Profile, recorder *soraProbeRecorder, ) { inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( @@ -989,7 +1003,7 @@ func (s *AccountTestService) testSora2Capabilities( authToken, soraInviteMineURL, proxyURL, - enableTLSFingerprint, + tlsProfile, ) if err != nil { if recorder != nil { @@ -1006,7 +1020,7 @@ func (s *AccountTestService) testSora2Capabilities( authToken, soraBootstrapURL, proxyURL, - enableTLSFingerprint, + 
tlsProfile, ) if bootstrapErr == nil && bootstrapStatus == http.StatusOK { if recorder != nil { @@ -1019,7 +1033,7 @@ func (s *AccountTestService) testSora2Capabilities( authToken, soraInviteMineURL, proxyURL, - enableTLSFingerprint, + tlsProfile, ) if err != nil { if recorder != nil { @@ -1071,7 +1085,7 @@ func (s *AccountTestService) testSora2Capabilities( authToken, soraRemainingURL, proxyURL, - enableTLSFingerprint, + tlsProfile, ) if remainingErr != nil { if recorder != nil { @@ -1112,7 +1126,7 @@ func (s *AccountTestService) fetchSoraTestEndpoint( authToken string, url string, proxyURL string, - enableTLSFingerprint bool, + tlsProfile *tlsfingerprint.Profile, ) (int, http.Header, []byte, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -1125,7 +1139,7 @@ func (s *AccountTestService) fetchSoraTestEndpoint( req.Header.Set("Origin", "https://sora.chatgpt.com") req.Header.Set("Referer", "https://sora.chatgpt.com/") - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile) if err != nil { return 0, nil, nil, err } @@ -1214,11 +1228,12 @@ func parseSoraRemainingSummary(body []byte) string { return strings.Join(parts, " | ") } -func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { - if s == nil || s.cfg == nil { - return true +func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile { + if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint { + // Sora TLS fingerprint enabled — use built-in default profile + return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"} } - return !s.cfg.Sora.Client.DisableTLSFingerprint + return nil // disabled } func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { diff --git a/backend/internal/service/account_test_service_sora_test.go 
b/backend/internal/service/account_test_service_sora_test.go index 3dfac786..52f832a9 100644 --- a/backend/internal/service/account_test_service_sora_test.go +++ b/backend/internal/service/account_test_service_sora_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -24,9 +25,9 @@ func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*htt return nil, fmt.Errorf("unexpected Do call") } -func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) { +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) { u.requests = append(u.requests, req) - u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + u.tlsFlags = append(u.tlsFlags, profile != nil) if len(u.responses) == 0 { return nil, fmt.Errorf("no mocked response") } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index f117abfd..0e5741d8 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -17,6 +17,7 @@ import ( openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "golang.org/x/sync/errgroup" "golang.org/x/sync/singleflight" @@ -48,6 +49,8 @@ type UsageLogRepository interface { GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) 
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) + GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) + GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) @@ -175,6 +178,7 @@ type AICredit struct { // UsageInfo 账号使用量信息 type UsageInfo struct { + Source string `json:"source,omitempty"` // "passive" or "active" UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 @@ -238,11 +242,11 @@ type ClaudeUsageResponse struct { // ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项 type ClaudeUsageFetchOptions struct { - AccessToken string // OAuth access token - ProxyURL string // 代理 URL(可选) - AccountID int64 // 账号 ID(用于 TLS 指纹选择) - EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装 - Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等) + AccessToken string // OAuth access token + ProxyURL string // 代理 URL(可选) + AccountID int64 // 账号 ID(用于连接池隔离) + TLSProfile *tlsfingerprint.Profile // TLS 指纹 Profile(nil 表示不启用) + 
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等) } // ClaudeUsageFetcher fetches usage data from Anthropic OAuth API @@ -261,6 +265,7 @@ type AccountUsageService struct { antigravityQuotaFetcher *AntigravityQuotaFetcher cache *UsageCache identityCache IdentityCache + tlsFPProfileService *TLSFingerprintProfileService } // NewAccountUsageService 创建AccountUsageService实例 @@ -272,6 +277,7 @@ func NewAccountUsageService( antigravityQuotaFetcher *AntigravityQuotaFetcher, cache *UsageCache, identityCache IdentityCache, + tlsFPProfileService *TLSFingerprintProfileService, ) *AccountUsageService { return &AccountUsageService{ accountRepo: accountRepo, @@ -281,6 +287,7 @@ func NewAccountUsageService( antigravityQuotaFetcher: antigravityQuotaFetcher, cache: cache, identityCache: identityCache, + tlsFPProfileService: tlsFPProfileService, } } @@ -391,6 +398,9 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U // 4. 添加窗口统计(有独立缓存,1 分钟) s.addWindowStats(ctx, account, usage) + // 5. 
将主动查询结果同步到被动缓存,下次 passive 加载即为最新值 + s.syncActiveToPassive(ctx, account.ID, usage) + s.tryClearRecoverableAccountError(ctx, account) return usage, nil } @@ -407,6 +417,81 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("account type %s does not support usage query", account.Type) } +// GetPassiveUsage 从 Account.Extra 中的被动采样数据构建 UsageInfo,不调用外部 API。 +// 仅适用于 Anthropic OAuth / SetupToken 账号。 +func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get account failed: %w", err) + } + + if !account.IsAnthropicOAuthOrSetupToken() { + return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts") + } + + // 复用 estimateSetupTokenUsage 构建 5h 窗口(OAuth 和 SetupToken 逻辑一致) + info := s.estimateSetupTokenUsage(account) + info.Source = "passive" + + // 设置采样时间 + if raw, ok := account.Extra["passive_usage_sampled_at"]; ok { + if str, ok := raw.(string); ok { + if t, err := time.Parse(time.RFC3339, str); err == nil { + info.UpdatedAt = &t + } + } + } + + // 构建 7d 窗口(从被动采样数据) + util7d := parseExtraFloat64(account.Extra["passive_usage_7d_utilization"]) + reset7dRaw := parseExtraFloat64(account.Extra["passive_usage_7d_reset"]) + if util7d > 0 || reset7dRaw > 0 { + var resetAt *time.Time + var remaining int + if reset7dRaw > 0 { + t := time.Unix(int64(reset7dRaw), 0) + resetAt = &t + remaining = int(time.Until(t).Seconds()) + if remaining < 0 { + remaining = 0 + } + } + info.SevenDay = &UsageProgress{ + Utilization: util7d * 100, + ResetsAt: resetAt, + RemainingSeconds: remaining, + } + } + + // 添加窗口统计 + s.addWindowStats(ctx, account, info) + + return info, nil +} + +// syncActiveToPassive 将主动查询的最新数据回写到 Extra 被动缓存, +// 这样下次被动加载时能看到最新值。 +func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID int64, usage *UsageInfo) { + 
extraUpdates := make(map[string]any, 4) + + if usage.FiveHour != nil { + extraUpdates["session_window_utilization"] = usage.FiveHour.Utilization / 100 + } + if usage.SevenDay != nil { + extraUpdates["passive_usage_7d_utilization"] = usage.SevenDay.Utilization / 100 + if usage.SevenDay.ResetsAt != nil { + extraUpdates["passive_usage_7d_reset"] = usage.SevenDay.ResetsAt.Unix() + } + } + + if len(extraUpdates) > 0 { + extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339) + if err := s.accountRepo.UpdateExtra(ctx, accountID, extraUpdates); err != nil { + slog.Warn("sync_active_to_passive_failed", "account_id", accountID, "error", err) + } + } +} + func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) { now := time.Now() usage := &UsageInfo{UpdatedAt: &now} @@ -446,23 +531,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou } if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil { - windowStats := windowStatsFromAccountStats(stats) - if hasMeaningfulWindowStats(windowStats) { - if usage.FiveHour == nil { - usage.FiveHour = &UsageProgress{Utilization: 0} - } - usage.FiveHour.WindowStats = windowStats + if usage.FiveHour == nil { + usage.FiveHour = &UsageProgress{Utilization: 0} } + usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats) } if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil { - windowStats := windowStatsFromAccountStats(stats) - if hasMeaningfulWindowStats(windowStats) { - if usage.SevenDay == nil { - usage.SevenDay = &UsageProgress{Utilization: 0} - } - usage.SevenDay.WindowStats = windowStats + if usage.SevenDay == nil { + usage.SevenDay = &UsageProgress{Utilization: 0} } + usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats) } return usage, nil @@ -992,13 +1071,6 @@ func windowStatsFromAccountStats(stats 
*usagestats.AccountStats) *WindowStats { } } -func hasMeaningfulWindowStats(stats *WindowStats) bool { - if stats == nil { - return false - } - return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0 -} - func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress { if len(extra) == 0 { return nil @@ -1055,6 +1127,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t } } + // 窗口已过期(resetAt 在 now 之前)→ 额度已重置,归零 + if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) { + progress.Utilization = 0 + } + return progress } @@ -1082,10 +1159,10 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A // 构建完整的选项 opts := &ClaudeUsageFetchOptions{ - AccessToken: accessToken, - ProxyURL: proxyURL, - AccountID: account.ID, - EnableTLSFingerprint: account.IsTLSFingerprintEnabled(), + AccessToken: accessToken, + ProxyURL: proxyURL, + AccountID: account.ID, + TLSProfile: s.tlsFPProfileService.ResolveTLSProfile(account), } // 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息) diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go index a063fe26..fe255225 100644 --- a/backend/internal/service/account_usage_service_test.go +++ b/backend/internal/service/account_usage_service_test.go @@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes t.Fatal("waiting for codex probe rate limit persistence timed out") } } + +func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) { + t.Parallel() + now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC) + + t.Run("expired 5h window zeroes utilization", func(t *testing.T) { + extra := map[string]any{ + "codex_5h_used_percent": 42.0, + "codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago + } + progress := buildCodexUsageProgressFromExtra(extra, "5h", now) + if 
progress == nil { + t.Fatal("expected non-nil progress") + } + if progress.Utilization != 0 { + t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization) + } + if progress.RemainingSeconds != 0 { + t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds) + } + }) + + t.Run("active 5h window keeps utilization", func(t *testing.T) { + resetAt := now.Add(2 * time.Hour).Format(time.RFC3339) + extra := map[string]any{ + "codex_5h_used_percent": 42.0, + "codex_5h_reset_at": resetAt, + } + progress := buildCodexUsageProgressFromExtra(extra, "5h", now) + if progress == nil { + t.Fatal("expected non-nil progress") + } + if progress.Utilization != 42.0 { + t.Fatalf("expected Utilization=42, got %v", progress.Utilization) + } + }) + + t.Run("expired 7d window zeroes utilization", func(t *testing.T) { + extra := map[string]any{ + "codex_7d_used_percent": 88.0, + "codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday + } + progress := buildCodexUsageProgressFromExtra(extra, "7d", now) + if progress == nil { + t.Fatal("expected non-nil progress") + } + if progress.Utilization != 0 { + t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization) + } + }) +} diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 0d7ffffa..d903b940 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected bool @@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { requestedModel: "claude-opus-4-5-thinking", expected: true, }, + { + name: "gemini customtools alias matches normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + 
"model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: true, + }, { name: "wildcard match not supported", credentials: map[string]any{ @@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.IsModelSupported(tt.requestedModel) @@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected string @@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", }, + { + name: "no mapping preserves gemini customtools model", + platform: PlatformGemini, + credentials: nil, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, // 精确匹配 { @@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { }, // 无匹配返回原始模型 + { + name: "gemini customtools alias resolves through normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview", + }, + { + name: "gemini customtools exact mapping wins over normalized fallback", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, { name: "no match returns original", credentials: 
map[string]any{ @@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.GetMappedModel(tt.requestedModel) @@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expectedModel string @@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { expectedModel: "gpt-5.4", expectedMatch: true, }, + { + name: "gemini customtools alias reports normalized match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview", + expectedMatch: true, + }, + { + name: "gemini customtools exact mapping reports exact match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview-customtools", + expectedMatch: true, + }, { name: "missing mapping reports unmatched", credentials: map[string]any{ @@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index ea76e171..0620d7ca 100644 --- a/backend/internal/service/admin_service.go +++ 
b/backend/internal/service/admin_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -50,8 +51,11 @@ type AdminService interface { // API Key management (admin) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) + // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 + ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) + // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -62,6 +66,12 @@ type AdminService interface { SetAccountError(ctx context.Context, id int64, errorMsg string) error // EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。 EnsureOpenAIPrivacy(ctx context.Context, account *Account) string + // EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号 privacy_mode,未设置则调用 setUserSettings 并持久化。 + EnsureAntigravityPrivacy(ctx context.Context, account *Account) string + // ForceOpenAIPrivacy 强制重新设置 OpenAI OAuth 账号隐私,无论当前状态。 + ForceOpenAIPrivacy(ctx context.Context, account *Account) string + // ForceAntigravityPrivacy 强制重新设置 Antigravity OAuth 账号隐私,无论当前状态。 + ForceAntigravityPrivacy(ctx context.Context, account *Account) string SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, 
currentAccountPlatform string, groupIDs []int64) error @@ -153,6 +163,8 @@ type CreateGroupInput struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -192,6 +204,8 @@ type UpdateGroupInput struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -270,6 +284,11 @@ type AdminUpdateAPIKeyGroupIDResult struct { GrantedGroupName string // the group name that was auto-granted } +// ReplaceUserGroupResult 分组替换操作的结果 +type ReplaceUserGroupResult struct { + MigratedKeys int64 // 迁移的 Key 数量 +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -927,12 +946,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn SupportedModelScopes: input.SupportedModelScopes, SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, AllowMessagesDispatch: input.AllowMessagesDispatch, + RequireOAuthOnly: input.RequireOAuthOnly, + RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + 
oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 如果有需要复制的账号,绑定到新分组 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { @@ -1140,6 +1182,12 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.AllowMessagesDispatch != nil { group.AllowMessagesDispatch = *input.AllowMessagesDispatch } + if input.RequireOAuthOnly != nil { + group.RequireOAuthOnly = *input.RequireOAuthOnly + } + if input.RequirePrivacySet != nil { + group.RequirePrivacySet = *input.RequirePrivacySet + } if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } @@ -1187,6 +1235,27 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 再绑定源分组的账号 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { @@ -1377,10 +1446,75 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx 
context.Context, keyID i return result, nil } +// ReplaceUserGroup 替换用户的专属分组 +func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { + if oldGroupID == newGroupID { + return nil, infraerrors.BadRequest("SAME_GROUP", "old and new group must be different") + } + + // 验证新分组存在且为活跃的专属标准分组 + newGroup, err := s.groupRepo.GetByID(ctx, newGroupID) + if err != nil { + return nil, err + } + if newGroup.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + if !newGroup.IsExclusive { + return nil, infraerrors.BadRequest("GROUP_NOT_EXCLUSIVE", "target group is not exclusive") + } + if newGroup.IsSubscriptionType() { + return nil, infraerrors.BadRequest("GROUP_IS_SUBSCRIPTION", "subscription groups are not supported for replacement") + } + + // 事务保证原子性 + if s.entClient == nil { + return nil, fmt.Errorf("entClient is nil, cannot perform group replacement") + } + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + opCtx := dbent.NewTxContext(ctx, tx) + + // 1. 授予新分组权限 + if err := s.userRepo.AddGroupToAllowedGroups(opCtx, userID, newGroupID); err != nil { + return nil, fmt.Errorf("add new group to allowed groups: %w", err) + } + + // 2. 迁移绑定旧分组的 Key 到新分组 + migrated, err := s.apiKeyRepo.UpdateGroupIDByUserAndGroup(opCtx, userID, oldGroupID, newGroupID) + if err != nil { + return nil, fmt.Errorf("migrate api keys: %w", err) + } + + // 3. 
移除旧分组权限 + if err := s.userRepo.RemoveGroupFromUserAllowedGroups(opCtx, userID, oldGroupID); err != nil { + return nil, fmt.Errorf("remove old group from allowed groups: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 失效该用户所有 Key 的认证缓存 + if s.authCacheInvalidator != nil { + keys, keyErr := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if keyErr == nil { + for _, k := range keys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, k) + } + } + } + + return &ReplaceUserGroupResult{MigratedKeys: migrated}, nil +} + // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) if err != nil { return nil, 0, err } @@ -1508,6 +1642,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Antigravity OAuth 账号:创建后异步设置隐私 + if account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth { + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r) + } + }() + s.EnsureAntigravityPrivacy(context.Background(), account) + }() + } + return account, nil } @@ -1530,7 +1676,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if len(input.Credentials) > 0 { account.Credentials = 
input.Credentials } - if len(input.Extra) > 0 { + // Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。 + // 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。 + if input.Extra != nil { // 保留配额用量字段,防止编辑账号时意外重置 for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} { if v, ok := account.Extra[key]; ok { @@ -1785,6 +1933,18 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac if err := s.accountRepo.ClearError(ctx, id); err != nil { return nil, err } + if err := s.accountRepo.ClearRateLimit(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearModelRateLimits(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil { + return nil, err + } return s.accountRepo.GetByID(ctx, id) } @@ -2560,10 +2720,8 @@ func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Acc if s.privacyClientFactory == nil { return "" } - if account.Extra != nil { - if _, ok := account.Extra["privacy_mode"]; ok { - return "" - } + if shouldSkipOpenAIPrivacyEnsure(account.Extra) { + return "" } token, _ := account.Credentials["access_token"].(string) @@ -2586,3 +2744,115 @@ func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Acc _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}) return mode } + +// ForceOpenAIPrivacy 强制重新设置 OpenAI OAuth 账号隐私,无论当前状态。 +func (s *adminServiceImpl) ForceOpenAIPrivacy(ctx context.Context, account *Account) string { + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "" + } + if s.privacyClientFactory == nil { + return "" + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return "" + } + + var proxyURL string + if account.ProxyID != nil { + if p, err := 
s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL) + if mode == "" { + return "" + } + + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil { + logger.LegacyPrintf("service.admin", "force_update_openai_privacy_mode_failed: account_id=%d err=%v", account.ID, err) + return mode + } + if account.Extra == nil { + account.Extra = make(map[string]any) + } + account.Extra["privacy_mode"] = mode + return mode +} + +// EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号隐私状态。 +// 如果 Extra["privacy_mode"] 已存在(无论成功或失败),直接跳过。 +// 仅对从未设置过隐私的账号执行 setUserSettings + fetchUserInfo 流程。 +// 用户可通过前端 ForceAntigravityPrivacy(SetPrivacy 按钮)强制重新设置。 +func (s *adminServiceImpl) EnsureAntigravityPrivacy(ctx context.Context, account *Account) string { + if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { + return "" + } + // 已设置过则跳过(无论成功或失败),用户可通过 Force 手动重试 + if account.Extra != nil { + if existing, ok := account.Extra["privacy_mode"].(string); ok && existing != "" { + return existing + } + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return "" + } + + projectID, _ := account.Credentials["project_id"].(string) + + var proxyURL string + if account.ProxyID != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := setAntigravityPrivacy(ctx, token, projectID, proxyURL) + if mode == "" { + return "" + } + + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil { + logger.LegacyPrintf("service.admin", "update_antigravity_privacy_mode_failed: account_id=%d err=%v", account.ID, err) + return mode + } + applyAntigravityPrivacyMode(account, mode) + return mode +} + +// ForceAntigravityPrivacy 强制重新设置 Antigravity OAuth 账号隐私,无论当前状态。 +func (s 
*adminServiceImpl) ForceAntigravityPrivacy(ctx context.Context, account *Account) string { + if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { + return "" + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return "" + } + + projectID, _ := account.Credentials["project_id"].(string) + + var proxyURL string + if account.ProxyID != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := setAntigravityPrivacy(ctx, token, projectID, proxyURL) + if mode == "" { + return "" + } + + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil { + logger.LegacyPrintf("service.admin", "force_update_antigravity_privacy_mode_failed: account_id=%d err=%v", account.ID, err) + return mode + } + applyAntigravityPrivacyMode(account, mode) + return mode +} diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 88d2f492..f9fd6742 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -65,6 +65,9 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } @@ -128,6 +131,9 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, 
int64, int64, int64) (int64, error) { + panic("unexpected") +} func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } @@ -194,7 +200,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go new file mode 100644 index 00000000..f039612c --- /dev/null +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -0,0 +1,86 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type accountRepoStubForClearAccountError struct { + mockAccountRepoForGemini + account *Account + clearErrorCalls int + clearRateLimitCalls int + clearAntigravityCalls int + clearModelRateLimitCalls int + clearTempUnschedCalls int +} + +func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) { + return r.account, nil +} + +func (r *accountRepoStubForClearAccountError) ClearError(ctx context.Context, id int64) error { + r.clearErrorCalls++ + r.account.Status = StatusActive + r.account.ErrorMessage = "" + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearRateLimit(ctx context.Context, id int64) error { + r.clearRateLimitCalls++ + r.account.RateLimitedAt = nil + r.account.RateLimitResetAt = nil + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearAntigravityQuotaScopes(ctx 
context.Context, id int64) error { + r.clearAntigravityCalls++ + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearModelRateLimits(ctx context.Context, id int64) error { + r.clearModelRateLimitCalls++ + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempUnschedCalls++ + r.account.TempUnschedulableUntil = nil + r.account.TempUnschedulableReason = "" + return nil +} + +func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *testing.T) { + until := time.Now().Add(10 * time.Minute) + resetAt := time.Now().Add(5 * time.Minute) + repo := &accountRepoStubForClearAccountError{ + account: &Account{ + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "refresh failed", + RateLimitResetAt: &resetAt, + TempUnschedulableUntil: &until, + TempUnschedulableReason: "missing refresh token", + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + updated, err := svc.ClearAccountError(context.Background(), 31) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Nil(t, updated.RateLimitResetAt) + require.Nil(t, updated.TempUnschedulableUntil) + require.Empty(t, updated.TempUnschedulableReason) +} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 2e0f7d90..fbc856cf 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID panic("unexpected RemoveGroupFromAllowedGroups call") } +func (s *userRepoStub) 
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { panic("unexpected AddGroupToAllowedGroups call") } @@ -160,7 +164,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er panic("unexpected ExistsByName call") } -func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { +func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index ef77a980..536be0b5 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string panic("unexpected ExistsByName call") } -func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s 
*groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/admin_service_overages_test.go b/backend/internal/service/admin_service_overages_test.go index 779b08b9..d6380f4d 100644 --- a/backend/internal/service/admin_service_overages_test.go +++ b/backend/internal/service/admin_service_overages_test.go @@ -121,3 +121,35 @@ func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testi _, exists := repo.account.Extra[modelRateLimitsKey] require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流") } + +func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) { + accountID := int64(103) + repo := &updateAccountOveragesRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Extra: map[string]any{ + "quota_limit": 100.0, + "quota_daily_limit": 10.0, + "quota_weekly_limit": 40.0, + }, + }, + } + + svc := &adminServiceImpl{accountRepo: repo} + updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + // 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制) + Extra: map[string]any{}, + }) + + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 1, repo.updateCalls) + require.NotNil(t, repo.account.Extra) + require.NotContains(t, repo.account.Extra, "quota_limit") + require.NotContains(t, repo.account.Extra, "quota_daily_limit") + require.NotContains(t, repo.account.Extra, "quota_weekly_limit") + require.Len(t, repo.account.Extra, 0) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index ff58fd01..eb213e6a 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct { listWithFiltersType string 
listWithFiltersStatus string listWithFiltersSearch string + listWithFiltersPrivacy string listWithFiltersAccounts []Account listWithFiltersResult *pagination.PaginationResult listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform s.listWithFiltersType = accountType s.listWithFiltersStatus = status s.listWithFiltersSearch = search + s.listWithFiltersPrivacy = privacyMode if s.listWithFiltersErr != nil { return nil, nil, s.listWithFiltersErr @@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) @@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { }) } +func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { + t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) { + repo := &accountRepoStubForAdminList{ + listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 1}, + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, 
PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) + require.Equal(t, PrivacyModeCFBlocked, repo.listWithFiltersPrivacy) + }) +} + func TestAdminService_ListProxies_WithSearch(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { repo := &proxyRepoStubForAdminList{ diff --git a/backend/internal/service/antigravity_credits_overages.go b/backend/internal/service/antigravity_credits_overages.go index 1521dfcd..ec365085 100644 --- a/backend/internal/service/antigravity_credits_overages.go +++ b/backend/internal/service/antigravity_credits_overages.go @@ -45,6 +45,7 @@ var ( "minimumcreditamountforusage", "minimum credit amount for usage", "minimum credit", + "resource has been exhausted", } ) @@ -147,9 +148,9 @@ func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr err if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout { return false } - if isURLLevelRateLimit(respBody) { - return false - } + // 注意:不再检查 isURLLevelRateLimit。此函数仅在积分重试失败后调用, + // 如果注入 enabledCreditTypes 后仍返回 "Resource has been exhausted", + // 说明积分也已耗尽,应该标记。clearCreditsExhausted 会在后续成功时自动清除。 if info := parseAntigravitySmartRetryInfo(respBody); info != nil { return false } diff --git a/backend/internal/service/antigravity_credits_overages_test.go b/backend/internal/service/antigravity_credits_overages_test.go index bc679494..7a5224da 100644 --- a/backend/internal/service/antigravity_credits_overages_test.go +++ b/backend/internal/service/antigravity_credits_overages_test.go @@ -406,10 +406,16 @@ func TestShouldMarkCreditsExhausted(t *testing.T) { require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil)) }) - t.Run("URL 级限流不标记", func(t *testing.T) { + t.Run("Resource has been exhausted 应标记为积分耗尽", func(t *testing.T) { resp := 
&http.Response{StatusCode: http.StatusTooManyRequests} body := []byte(`{"error":{"message":"Resource has been exhausted"}}`) - require.False(t, shouldMarkCreditsExhausted(resp, body, nil)) + require.True(t, shouldMarkCreditsExhausted(resp, body, nil)) + }) + + t.Run("Resource has been exhausted (check quota) 完整格式应标记", func(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`) + require.True(t, shouldMarkCreditsExhausted(resp, body, nil)) }) t.Run("结构化限流不标记", func(t *testing.T) { diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index cafc2a79..a76e59fb 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP urlFallbackLoop: for urlIdx, baseURL := range availableURLs { usedBaseURL = baseURL + allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500 for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { select { case <-p.ctx.Done(): @@ -643,6 +644,7 @@ urlFallbackLoop: AccountID: p.account.ID, AccountName: p.account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -720,6 +722,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), @@ -754,6 +757,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: 
"retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), @@ -763,10 +767,19 @@ urlFallbackLoop: logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } + // 追踪 INTERNAL 500:非匹配的 attempt 清除标记 + if !isAntigravityInternalServerError(resp.StatusCode, respBody) { + allAttemptsInternal500 = false + } continue } } + // INTERNAL 500 渐进惩罚:3 次重试全部命中特定 500 时递增计数器并惩罚 + if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) { + s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account) + } + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, @@ -785,6 +798,11 @@ urlFallbackLoop: antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL) } + // 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry) + if resp != nil && resp.StatusCode < 400 { + s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID) + } + return &antigravityRetryLoopResult{resp: resp}, nil } @@ -859,6 +877,7 @@ type AntigravityGatewayService struct { settingService *SettingService cache GatewayCache // 用于模型级限流时清除粘性会话绑定 schedulerSnapshot *SchedulerSnapshotService + internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器 } func NewAntigravityGatewayService( @@ -869,6 +888,7 @@ func NewAntigravityGatewayService( rateLimitService *RateLimitService, httpUpstream HTTPUpstream, settingService *SettingService, + internal500Cache Internal500CounterCache, ) *AntigravityGatewayService { return &AntigravityGatewayService{ accountRepo: accountRepo, @@ -878,6 +898,7 @@ func NewAntigravityGatewayService( settingService: settingService, cache: cache, schedulerSnapshot: schedulerSnapshot, + internal500Cache: internal500Cache, } } @@ -930,7 +951,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam case ErrorPolicyTempUnscheduled: slog.Info("temp_unschedulable_matched", "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) - 
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession} } return false, statusCode, nil } @@ -1001,8 +1022,9 @@ type TestConnectionResult struct { MappedModel string // 实际使用的模型 } -// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) -// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 +// TestConnection 测试 Antigravity 账号连接。 +// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑, +// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { // 获取 token @@ -1026,10 +1048,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 构建请求体 var requestBody []byte if strings.HasPrefix(modelID, "gemini-") { - // Gemini 模型:直接使用 Gemini 格式 requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel) } else { - // Claude 模型:使用协议转换 requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel) } if err != nil { @@ -1042,64 +1062,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account proxyURL = account.Proxy.URL() } - baseURL := resolveAntigravityForwardBaseURL() - if baseURL == "" { - return nil, errors.New("no antigravity forward base url configured") - } - availableURLs := []string{baseURL} - - var lastErr error - for urlIdx, baseURL := range availableURLs { - // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) - req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody) - if err != nil { - lastErr = err - continue - } - - // 调试日志:Test 请求信息 - logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) - - // 发送请求 - resp, err := 
s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = fmt.Errorf("请求失败: %w", err) - if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) - continue - } - return nil, lastErr - } - - // 读取响应 - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - // 检查是否需要 URL 降级 - if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) - continue - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) - } - - // 解析流式响应,提取文本 - text := extractTextFromSSEResponse(respBody) - - // 标记成功的 URL,下次优先使用 - antigravity.DefaultURLAvailability.MarkSuccess(baseURL) - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil + // 复用 antigravityRetryLoop:完整的重试 / credits overages / 智能重试 + prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name) + p := antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: "streamGenerateContent", + body: requestBody, + c: nil, // 无 gin.Context → 跳过 ops 追踪 + httpUpstream: s.httpUpstream, + settingService: s.settingService, + accountRepo: s.accountRepo, + requestedModel: modelID, + handleError: testConnectionHandleError, } - return nil, lastErr + result, err := s.antigravityRetryLoop(p) + if err != nil { + // AccountSwitchError → 测试时不切换账号,返回友好提示 + var switchErr *AntigravityAccountSwitchError + if errors.As(err, &switchErr) { + return nil, fmt.Errorf("该账号模型 %s 
当前限流中,请稍后重试", switchErr.RateLimitedModel) + } + return nil, err + } + + if result == nil || result.resp == nil { + return nil, errors.New("upstream returned empty response") + } + defer func() { _ = result.resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if result.resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody)) + } + + text := extractTextFromSSEResponse(respBody) + return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil +} + +// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。 +// 仅记录日志,不做 ops 错误追踪或粘性会话清除。 +func testConnectionHandleError( + _ context.Context, prefix string, account *Account, + statusCode int, _ http.Header, body []byte, + requestedModel string, _ int64, _ string, _ bool, +) *handleModelRateLimitResult { + logger.LegacyPrintf("service.antigravity_gateway", + "%s test_handle_error status=%d model=%s account=%d body=%s", + prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200)) + return nil } // buildGeminiTestRequest 构建 Gemini 格式测试请求 @@ -1361,7 +1380,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) if err != nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Failed to get upstream access token") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":{"type":"authentication_error","message":"Failed to get upstream access token"},"type":"error"}`), + } } // 获取 project_id(部分账户类型可能没有) @@ -1741,7 +1763,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, // 使用映射模型用于计费和日志 + Model: originalModel, + UpstreamModel: billingModel, 
Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2103,7 +2126,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) if err != nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Failed to get upstream access token") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":{"message":"Failed to get upstream access token","status":"UNAVAILABLE"}}`), + } } // 获取 project_id(部分账户类型可能没有) @@ -2431,7 +2457,8 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, + Model: originalModel, + UpstreamModel: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -3079,6 +3106,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") // 仅发送一次错误事件,避免多次写入导致协议混乱 @@ -3111,6 +3154,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return nil, ev.err } + lastDataAt = time.Now() + line := ev.line trimmed := strings.TrimRight(line, "\r\n") if strings.HasPrefix(trimmed, "data:") { @@ -3170,6 +3215,19 @@ func (s *AntigravityGatewayService) 
handleGeminiStreamingResponse(c *gin.Context logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if cw.Disconnected() { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping/keepalive:保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if !cw.Fprintf(":\n\n") { + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing") + continue + } } } } @@ -3895,6 +3953,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") // 仅发送一次错误事件,避免多次写入导致协议混乱 @@ -3947,6 +4021,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } + lastDataAt = time.Now() + // 处理 SSE 行,转换为 Claude 格式 claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) if len(claudeEvents) > 0 { @@ -3969,6 +4045,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval 
timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if cw.Disconnected() { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") { + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing") + continue + } } } } @@ -4299,6 +4389,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + flusher, _ := c.Writer.(http.Flusher) cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream") @@ -4316,6 +4422,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} } + lastDataAt = time.Now() + line := ev.line // 记录首 token 时间 @@ -4341,6 +4449,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp } logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + + 
case <-keepaliveCh: + if cw.Disconnected() { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") { + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing") + continue + } } } } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 6e0a7305..1eb1451e 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -130,7 +131,7 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http. 
return s.resp, s.err } -func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { +func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ *tlsfingerprint.Profile) (*http.Response, error) { return s.resp, s.err } @@ -171,7 +172,7 @@ func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) return resp, err } -func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) { +func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ *tlsfingerprint.Profile) (*http.Response, error) { return s.Do(req, proxyURL, accountID, concurrency) } @@ -542,7 +543,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { result, err := svc.Forward(context.Background(), c, account, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "claude-sonnet-4-5", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel @@ -594,7 +596,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { @@ -664,7 +667,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) 
require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, originalModel, result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") firstReq := string(upstream.requestBodies[0]) diff --git a/backend/internal/service/antigravity_internal500_penalty.go b/backend/internal/service/antigravity_internal500_penalty.go new file mode 100644 index 00000000..747a4d4e --- /dev/null +++ b/backend/internal/service/antigravity_internal500_penalty.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/tidwall/gjson" +) + +// INTERNAL 500 渐进惩罚:连续多轮全部返回特定 500 错误时的惩罚时长 +const ( + internal500PenaltyTier1Duration = 30 * time.Minute // 第 1 轮:临时不可调度 30 分钟 + internal500PenaltyTier2Duration = 2 * time.Hour // 第 2 轮:临时不可调度 2 小时 + internal500PenaltyTier3Threshold = 3 // 第 3+ 轮:永久禁用 +) + +// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误 +// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL" +func isAntigravityInternalServerError(statusCode int, body []byte) bool { + if statusCode != http.StatusInternalServerError { + return false + } + return gjson.GetBytes(body, "error.code").Int() == 500 && + gjson.GetBytes(body, "error.message").String() == "Internal error encountered." 
&& + gjson.GetBytes(body, "error.status").String() == "INTERNAL" +} + +// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚 +// count=1: temp_unschedulable 30 分钟 +// count=2: temp_unschedulable 2 小时 +// count>=3: SetError 永久禁用 +func (s *AntigravityGatewayService) applyInternal500Penalty( + ctx context.Context, prefix string, account *Account, count int64, +) { + switch { + case count >= int64(internal500PenaltyTier3Threshold): + reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count) + if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil { + slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err) + return + } + slog.Warn("internal500_account_disabled", + "account_id", account.ID, "account_name", account.Name, "consecutive_count", count) + case count == 2: + until := time.Now().Add(internal500PenaltyTier2Duration) + reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err) + return + } + slog.Warn("internal500_temp_unschedulable", + "account_id", account.ID, "account_name", account.Name, + "duration", internal500PenaltyTier2Duration, "consecutive_count", count) + case count == 1: + until := time.Now().Add(internal500PenaltyTier1Duration) + reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err) + return + } + slog.Info("internal500_temp_unschedulable", + "account_id", account.ID, "account_name", account.Name, + "duration", internal500PenaltyTier1Duration, "consecutive_count", count) + } +} + +// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚
+func (s *AntigravityGatewayService) handleInternal500RetryExhausted( + ctx context.Context, prefix string, account *Account, +) { + if s.internal500Cache == nil { + return + } + count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID) + if err != nil { + slog.Error("internal500_counter_increment_failed", + "prefix", prefix, "account_id", account.ID, "error", err) + return + } + s.applyInternal500Penalty(ctx, prefix, account, count) +} + +// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器 +func (s *AntigravityGatewayService) resetInternal500Counter( + ctx context.Context, prefix string, accountID int64, +) { + if s.internal500Cache == nil { + return + } + if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil { + slog.Error("internal500_counter_reset_failed", + "prefix", prefix, "account_id", accountID, "error", err) + } +} diff --git a/backend/internal/service/antigravity_internal500_penalty_test.go b/backend/internal/service/antigravity_internal500_penalty_test.go new file mode 100644 index 00000000..03831839 --- /dev/null +++ b/backend/internal/service/antigravity_internal500_penalty_test.go @@ -0,0 +1,321 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- mock: Internal500CounterCache --- + +type mockInternal500Cache struct { + incrementCount int64 + incrementErr error + resetErr error + + incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID + resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID +} + +func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) { + m.incrementCalls = append(m.incrementCalls, accountID) + return m.incrementCount, m.incrementErr +} + +func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error { + m.resetCalls = append(m.resetCalls, accountID) + return m.resetErr +} + +// --- mock: 
专用于 internal500 惩罚测试的 AccountRepository --- + +type internal500AccountRepoStub struct { + AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用) + + tempUnschedCalls []tempUnschedCall + setErrorCalls []setErrorCall +} + +type tempUnschedCall struct { + accountID int64 + until time.Time + reason string +} + +type setErrorCall struct { + accountID int64 + reason string +} + +func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error { + r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason}) + return nil +} + +func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error { + r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg}) + return nil +} + +// ============================================================================= +// TestIsAntigravityInternalServerError +// ============================================================================= + +func TestIsAntigravityInternalServerError(t *testing.T) { + t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.True(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("statusCode 不是 500", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(429, body)) + require.False(t, isAntigravityInternalServerError(503, body)) + require.False(t, isAntigravityInternalServerError(200, body)) + }) + + t.Run("body 中 message 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 status 不匹配", func(t *testing.T) { + body := 
[]byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 code 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("空 body", func(t *testing.T) { + require.False(t, isAntigravityInternalServerError(500, []byte{})) + require.False(t, isAntigravityInternalServerError(500, nil)) + }) + + t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) { + body := []byte(`Internal Server Error`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) { + body := []byte(`{"message":"Internal Server Error","statusCode":500}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) +} + +// ============================================================================= +// TestApplyInternal500Penalty +// ============================================================================= + +func TestApplyInternal500Penalty(t *testing.T) { + t.Run("count=1 → SetTempUnschedulable 30 分钟", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 1, Name: "acc-1"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 1) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(1), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + // until 应在 [before+30m, after+30m] 范围内 + require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second))) + }) + + t.Run("count=2 → 
SetTempUnschedulable 2 小时", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 2, Name: "acc-2"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 2) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(2), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second))) + }) + + t.Run("count=3 → SetError 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 3, Name: "acc-3"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 3) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(3), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3") + }) + + t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 5, Name: "acc-5"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 5) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(5), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5") + }) + + t.Run("count=0 → 不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 10, Name: "acc-10"} + + 
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0) + + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) +} + +// ============================================================================= +// TestHandleInternal500RetryExhausted +// ============================================================================= + +func TestHandleInternal500RetryExhausted(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: nil, + } + account := &Account{ID: 1, Name: "acc-1"} + + // 不应 panic + require.NotPanics(t, func() { + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + }) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementErr: errors.New("redis connection error"), + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 2, Name: "acc-2"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(2), cache.incrementCalls[0]) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 1, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 3, Name: "acc-3"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(3), cache.incrementCalls[0]) + 
// tier1: SetTempUnschedulable + require.Len(t, repo.tempUnschedCalls, 1) + require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 3, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 4, Name: "acc-4"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + require.Equal(t, int64(4), repo.setErrorCalls[0].accountID) + }) +} + +// ============================================================================= +// TestResetInternal500Counter +// ============================================================================= + +func TestResetInternal500Counter(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) { + svc := &AntigravityGatewayService{ + internal500Cache: nil, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 1) + }) + }) + + t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) { + cache := &mockInternal500Cache{ + resetErr: errors.New("redis timeout"), + } + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 42) + }) + require.Len(t, cache.resetCalls, 1) + require.Equal(t, int64(42), cache.resetCalls[0]) + }) + + t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) { + cache := &mockInternal500Cache{} + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + svc.resetInternal500Counter(context.Background(), "[test]", 99) + + require.Len(t, cache.resetCalls, 1) + require.Equal(t, 
int64(99), cache.resetCalls[0]) + }) +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 71939d26..a29000e7 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "claude-opus-4-6-thinking", }, { - name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6", requestedModel: "claude-haiku-4-5", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-sonnet-4-6", }, { - name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6", requestedModel: "claude-haiku-4-5-20251001", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-sonnet-4-6", }, { name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5", @@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { requestedModel: "gemini-2.5-flash", expected: "gemini-2.5-flash", }, + { + name: "customtools alias falls back to normalized preview mapping", + modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"}, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-high", + }, } for _, tt := range tests { diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index 5f6691be..a300d898 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -89,7 +89,8 @@ type AntigravityTokenInfo struct { TokenType string `json:"token_type"` Email string `json:"email,omitempty"` ProjectID string `json:"project_id,omitempty"` - ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id + 
ProjectIDMissing bool `json:"-"` + PlanType string `json:"-"` } // ExchangeCode 用 authorization code 交换 token @@ -145,13 +146,17 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig result.Email = userInfo.Email } - // 获取 project_id(部分账户类型可能没有),失败时重试 - projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3) + // 获取 project_id + plan_type(部分账户类型可能没有),失败时重试 + loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3) if loadErr != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) result.ProjectIDMissing = true - } else { - result.ProjectID = projectID + } + if loadResult != nil { + result.ProjectID = loadResult.ProjectID + if loadResult.Subscription != nil { + result.PlanType = loadResult.Subscription.PlanType + } } return result, nil @@ -192,6 +197,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken if isNonRetryableAntigravityOAuthError(err) { return nil, err } + // 代理连接错误(TCP 超时、连接拒绝、DNS 失败)不重试,立即返回 + if antigravity.IsConnectionError(err) { + return nil, fmt.Errorf("proxy unavailable: %w", err) + } lastErr = err } @@ -226,13 +235,17 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr tokenInfo.Email = userInfo.Email } - // 获取 project_id(容错,失败不阻塞) - projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) + // 获取 project_id + plan_type(容错,失败不阻塞) + loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) if loadErr != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) tokenInfo.ProjectIDMissing = true - } else { - tokenInfo.ProjectID = projectID + } + if loadResult != nil { + tokenInfo.ProjectID = loadResult.ProjectID + if loadResult.Subscription != nil { + tokenInfo.PlanType = loadResult.Subscription.PlanType + } } return tokenInfo, nil @@ -284,33 +297,42 @@ func (s 
*AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou tokenInfo.Email = existingEmail } - // 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试 + // 每次刷新都调用 LoadCodeAssist 获取 project_id + plan_type,失败时重试 existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) - projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) + loadResult, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) if loadErr != nil { - // LoadCodeAssist 失败,保留原有 project_id tokenInfo.ProjectID = existingProjectID - // 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失 - // 如果之前有 project_id,本次只是临时故障,不应标记为错误 if existingProjectID == "" { tokenInfo.ProjectIDMissing = true } - } else { - tokenInfo.ProjectID = projectID + } + if loadResult != nil { + if loadResult.ProjectID != "" { + tokenInfo.ProjectID = loadResult.ProjectID + } + if loadResult.Subscription != nil { + tokenInfo.PlanType = loadResult.Subscription.PlanType + } } return tokenInfo, nil } -// loadProjectIDWithRetry 带重试机制获取 project_id -// 返回 project_id 和错误,失败时会重试指定次数 -func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) { +// loadCodeAssistResult 封装 loadProjectIDWithRetry 的返回结果, +// 同时携带从 LoadCodeAssist 响应中提取的 plan_type 信息。 +type loadCodeAssistResult struct { + ProjectID string + Subscription *AntigravitySubscriptionResult +} + +// loadProjectIDWithRetry 带重试机制获取 project_id,同时从响应中提取 plan_type。 +func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (*loadCodeAssistResult, error) { var lastErr error + var lastSubscription *AntigravitySubscriptionResult for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { - // 指数退避:1s, 2s, 4s backoff := time.Duration(1< 8*time.Second { backoff = 8 * time.Second @@ -320,24 +342,34 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac client, 
err := antigravity.NewClient(proxyURL) if err != nil { - return "", fmt.Errorf("create antigravity client failed: %w", err) + return nil, fmt.Errorf("create antigravity client failed: %w", err) } loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) + if loadResp != nil { + sub := NormalizeAntigravitySubscription(loadResp) + lastSubscription = &sub + } + if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { - return loadResp.CloudAICompanionProject, nil + return &loadCodeAssistResult{ + ProjectID: loadResp.CloudAICompanionProject, + Subscription: lastSubscription, + }, nil } if err == nil { if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" { - return projectID, nil + return &loadCodeAssistResult{ + ProjectID: projectID, + Subscription: lastSubscription, + }, nil } else if onboardErr != nil { lastErr = onboardErr continue } } - // 记录错误 if err != nil { lastErr = err } else if loadResp == nil { @@ -347,7 +379,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac } } - return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) + if lastSubscription != nil { + return &loadCodeAssistResult{Subscription: lastSubscription}, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) + } + return nil, fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) } func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) { @@ -406,7 +441,11 @@ func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Ac proxyURL = proxy.URL() } } - return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3) + result, err := s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3) + if result != nil { + return result.ProjectID, err + } + return "", err } // BuildAccountCredentials 构建账户凭证 @@ -427,6 +466,9 @@ func (s 
*AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } + if tokenInfo.PlanType != "" { + creds["plan_type"] = tokenInfo.PlanType + } return creds } diff --git a/backend/internal/service/antigravity_privacy_service.go b/backend/internal/service/antigravity_privacy_service.go new file mode 100644 index 00000000..50fe07f6 --- /dev/null +++ b/backend/internal/service/antigravity_privacy_service.go @@ -0,0 +1,81 @@ +package service + +import ( + "context" + "log/slog" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +const ( + AntigravityPrivacySet = "privacy_set" + AntigravityPrivacyFailed = "privacy_set_failed" +) + +// setAntigravityPrivacy 调用 Antigravity API 设置隐私并验证结果。 +// 流程: +// 1. setUserSettings 清空设置 → 检查返回值 {"userSettings":{}} +// 2. fetchUserInfo 二次验证隐私是否已生效(需要 project_id) +// +// 返回 privacy_mode 值:"privacy_set" 成功,"privacy_set_failed" 失败,空串表示无法执行。 +func setAntigravityPrivacy(ctx context.Context, accessToken, projectID, proxyURL string) string { + if accessToken == "" { + return "" + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + client, err := antigravity.NewClient(proxyURL) + if err != nil { + slog.Warn("antigravity_privacy_client_error", "error", err.Error()) + return AntigravityPrivacyFailed + } + + // 第 1 步:调用 setUserSettings,检查返回值 + setResp, err := client.SetUserSettings(ctx, accessToken) + if err != nil { + slog.Warn("antigravity_privacy_set_failed", "error", err.Error()) + return AntigravityPrivacyFailed + } + if !setResp.IsSuccess() { + slog.Warn("antigravity_privacy_set_response_not_empty", + "user_settings", setResp.UserSettings, + ) + return AntigravityPrivacyFailed + } + + // 第 2 步:调用 fetchUserInfo 二次验证隐私是否已生效 + if strings.TrimSpace(projectID) == "" { + slog.Warn("antigravity_privacy_missing_project_id") + return AntigravityPrivacyFailed + } + userInfo, err := client.FetchUserInfo(ctx, 
accessToken, projectID) + if err != nil { + slog.Warn("antigravity_privacy_verify_failed", "error", err.Error()) + return AntigravityPrivacyFailed + } + if !userInfo.IsPrivate() { + slog.Warn("antigravity_privacy_verify_not_private", + "user_settings", userInfo.UserSettings, + ) + return AntigravityPrivacyFailed + } + + slog.Info("antigravity_privacy_set_success") + return AntigravityPrivacySet +} + +func applyAntigravityPrivacyMode(account *Account, mode string) { + if account == nil || strings.TrimSpace(mode) == "" { + return + } + extra := make(map[string]any, len(account.Extra)+1) + for k, v := range account.Extra { + extra[k] = v + } + extra["privacy_mode"] = mode + account.Extra = extra +} diff --git a/backend/internal/service/antigravity_privacy_service_test.go b/backend/internal/service/antigravity_privacy_service_test.go new file mode 100644 index 00000000..893500a6 --- /dev/null +++ b/backend/internal/service/antigravity_privacy_service_test.go @@ -0,0 +1,67 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func applyAntigravitySubscriptionResult(account *Account, result AntigravitySubscriptionResult) (map[string]any, map[string]any) { + credentials := make(map[string]any) + for k, v := range account.Credentials { + credentials[k] = v + } + credentials["plan_type"] = result.PlanType + + extra := make(map[string]any) + for k, v := range account.Extra { + extra[k] = v + } + if result.SubscriptionStatus != "" { + extra["subscription_status"] = result.SubscriptionStatus + } else { + delete(extra, "subscription_status") + } + if result.SubscriptionError != "" { + extra["subscription_error"] = result.SubscriptionError + } else { + delete(extra, "subscription_error") + } + return credentials, extra +} + +func TestApplyAntigravityPrivacyMode_SetsInMemoryExtra(t *testing.T) { + account := &Account{} + + applyAntigravityPrivacyMode(account, AntigravityPrivacySet) + + if account.Extra == nil { + t.Fatal("expected account.Extra to be initialized") 
+ } + if got := account.Extra["privacy_mode"]; got != AntigravityPrivacySet { + t.Fatalf("expected privacy_mode %q, got %v", AntigravityPrivacySet, got) + } +} + +func TestApplyAntigravityPrivacyMode_PreservedBySubscriptionResult(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "access_token": "token", + }, + Extra: map[string]any{ + "existing": "value", + }, + } + applyAntigravityPrivacyMode(account, AntigravityPrivacySet) + + _, extra := applyAntigravitySubscriptionResult(account, AntigravitySubscriptionResult{ + PlanType: "Pro", + }) + + if got := extra["privacy_mode"]; got != AntigravityPrivacySet { + t.Fatalf("expected subscription writeback to keep privacy_mode %q, got %v", AntigravityPrivacySet, got) + } + if got := extra["existing"]; got != "value" { + t.Fatalf("expected existing extra fields to be preserved, got %v", got) + } +} diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index df1ce9b9..35e130dc 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/stretchr/testify/require" ) @@ -40,7 +41,7 @@ func (r *recordingOKUpstream) Do(req *http.Request, proxyURL string, accountID i }, nil } -func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (r *recordingOKUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return r.Do(req, proxyURL, accountID, accountConcurrency) } @@ -61,7 +62,7 @@ func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, account }, nil } -func (s 
*stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return s.Do(req, proxyURL, accountID, accountConcurrency) } diff --git a/backend/internal/service/antigravity_single_account_retry_test.go b/backend/internal/service/antigravity_single_account_retry_test.go index 8b01cc31..675e9c0c 100644 --- a/backend/internal/service/antigravity_single_account_retry_test.go +++ b/backend/internal/service/antigravity_single_account_retry_test.go @@ -260,14 +260,15 @@ func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *test // TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit // 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流 +// 使用 RATE_LIMIT_EXCEEDED(走 1 次智能重试),避免 MODEL_CAPACITY_EXHAUSTED 的 60 次重试导致测试超时 func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) { // 智能重试也返回 503 failRespBody := `{ "error": { "code": 503, - "status": "UNAVAILABLE", + "status": "RESOURCE_EXHAUSTED", "details": [ - {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} ] } @@ -278,8 +279,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi Body: io.NopCloser(strings.NewReader(failRespBody)), } upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{failResp}, - errors: []error{nil}, + responses: []*http.Response{failResp}, + errors: []error{nil}, + repeatLast: true, } repo := 
&stubAntigravityAccountRepo{} @@ -294,9 +296,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi respBody := []byte(`{ "error": { "code": 503, - "status": "UNAVAILABLE", + "status": "RESOURCE_EXHAUSTED", "details": [ - {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} ] } @@ -569,8 +571,9 @@ func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) { svc := &AntigravityGatewayService{} - // 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait - result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro") + // waitDuration=0 会被 clamp 到 antigravitySmartRetryMinWait=1s。 + // 首次重试即成功(200),总耗时 ~1s。 + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 0, "gemini-3-pro") require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) require.NotNil(t, result.resp) diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index f569219f..ecaffcbc 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/stretchr/testify/require" ) @@ -32,11 +33,13 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID // mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream type mockSmartRetryUpstream struct { - responses []*http.Response - errors []error - callIdx int - calls []string - requestBodies [][]byte + responses []*http.Response 
+ responseBodies [][]byte // 缓存的 response body 字节(用于 repeatLast 重建) + errors []error + callIdx int + calls []string + requestBodies [][]byte + repeatLast bool // 超出范围时重复最后一个响应 } func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { @@ -50,13 +53,48 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI m.requestBodies = append(m.requestBodies, nil) } m.callIdx++ - if idx < len(m.responses) { - return m.responses[idx], m.errors[idx] + + // 确定使用哪个索引 + respIdx := idx + if respIdx >= len(m.responses) { + if !m.repeatLast || len(m.responses) == 0 { + return nil, nil + } + respIdx = len(m.responses) - 1 } - return nil, nil + + resp := m.responses[respIdx] + respErr := m.errors[respIdx] + if resp == nil { + return nil, respErr + } + + // 首次调用时缓存 body 字节 + if respIdx >= len(m.responseBodies) { + for len(m.responseBodies) <= respIdx { + m.responseBodies = append(m.responseBodies, nil) + } + } + if m.responseBodies[respIdx] == nil && resp.Body != nil { + bodyBytes, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + m.responseBodies[respIdx] = bodyBytes + } + + // 用缓存的 body 字节重建新的 reader + var body io.ReadCloser + if m.responseBodies[respIdx] != nil { + body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx])) + } + + return &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: body, + }, respErr } -func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return m.Do(req, proxyURL, accountID, accountConcurrency) } diff --git a/backend/internal/service/antigravity_subscription_service.go 
b/backend/internal/service/antigravity_subscription_service.go new file mode 100644 index 00000000..04559be8 --- /dev/null +++ b/backend/internal/service/antigravity_subscription_service.go @@ -0,0 +1,38 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +const antigravitySubscriptionAbnormal = "abnormal" + +// AntigravitySubscriptionResult 表示订阅检测后的规范化结果。 +type AntigravitySubscriptionResult struct { + PlanType string + SubscriptionStatus string + SubscriptionError string +} + +// NormalizeAntigravitySubscription 从 LoadCodeAssistResponse 提取 plan_type + 异常状态。 +// 使用 GetTier()(返回 tier ID)+ TierIDToPlanType 映射。 +func NormalizeAntigravitySubscription(resp *antigravity.LoadCodeAssistResponse) AntigravitySubscriptionResult { + if resp == nil { + return AntigravitySubscriptionResult{PlanType: "Free"} + } + if len(resp.IneligibleTiers) > 0 { + result := AntigravitySubscriptionResult{ + PlanType: "Abnormal", + SubscriptionStatus: antigravitySubscriptionAbnormal, + } + if resp.IneligibleTiers[0] != nil { + result.SubscriptionError = strings.TrimSpace(resp.IneligibleTiers[0].ReasonMessage) + } + return result + } + tierID := resp.GetTier() + return AntigravitySubscriptionResult{ + PlanType: antigravity.TierIDToPlanType(tierID), + } +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 9cdc49aa..1b360d93 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -14,6 +14,10 @@ const ( antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute antigravityBackfillCooldown = 5 * time.Minute + // antigravityRequestRefreshTimeout 请求路径上 token 刷新的最大等待时间。 + // 超过此时间直接放弃刷新、标记账号临时不可调度并触发 failover, + // 让后台 TokenRefreshService 在下个周期继续重试。 + antigravityRequestRefreshTimeout = 8 * time.Second ) // AntigravityTokenCache token cache interface. 
@@ -28,6 +32,7 @@ type AntigravityTokenProvider struct { refreshAPI *OAuthRefreshAPI executor OAuthRefreshExecutor refreshPolicy ProviderRefreshPolicy + tempUnschedCache TempUnschedCache // 用于同步更新 Redis 临时不可调度缓存 } func NewAntigravityTokenProvider( @@ -54,6 +59,11 @@ func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy p.refreshPolicy = policy } +// SetTempUnschedCache injects temp unschedulable cache for immediate scheduler sync. +func (p *AntigravityTokenProvider) SetTempUnschedCache(cache TempUnschedCache) { + p.tempUnschedCache = cache +} + // GetAccessToken returns a valid access_token. func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { @@ -88,8 +98,13 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew if needsRefresh && p.refreshAPI != nil && p.executor != nil { - result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew) + // 请求路径使用短超时,避免代理不通时阻塞过久(后台刷新服务会继续重试) + refreshCtx, cancel := context.WithTimeout(ctx, antigravityRequestRefreshTimeout) + defer cancel() + result, err := p.refreshAPI.RefreshIfNeeded(refreshCtx, account, p.executor, antigravityTokenRefreshSkew) if err != nil { + // 标记账号临时不可调度,避免后续请求继续命中 + p.markTempUnschedulable(account, err) if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { return "", err } @@ -123,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * p.markBackfillAttempted(account.ID) if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { account.Credentials["project_id"] = projectID - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + if updateErr := 
persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil { slog.Warn("antigravity_project_id_backfill_persist_failed", "account_id", account.ID, "error", updateErr, @@ -172,6 +187,45 @@ func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool { return true } +// markTempUnschedulable 在请求路径上 token 刷新失败时标记账号临时不可调度。 +// 同时写 DB 和 Redis 缓存,确保调度器立即跳过该账号。 +// 使用 background context 因为请求 context 可能已超时。 +func (p *AntigravityTokenProvider) markTempUnschedulable(account *Account, refreshErr error) { + if p.accountRepo == nil || account == nil { + return + } + now := time.Now() + until := now.Add(tokenRefreshTempUnschedDuration) + reason := "token refresh failed on request path: " + refreshErr.Error() + bgCtx := context.Background() + if err := p.accountRepo.SetTempUnschedulable(bgCtx, account.ID, until, reason); err != nil { + slog.Warn("antigravity_token_provider.set_temp_unschedulable_failed", + "account_id", account.ID, + "error", err, + ) + return + } + slog.Warn("antigravity_token_provider.temp_unschedulable_set", + "account_id", account.ID, + "until", until.Format(time.RFC3339), + "reason", reason, + ) + // 同步写 Redis 缓存,调度器立即生效 + if p.tempUnschedCache != nil { + state := &TempUnschedState{ + UntilUnix: until.Unix(), + TriggeredAtUnix: now.Unix(), + ErrorMessage: reason, + } + if err := p.tempUnschedCache.SetTempUnsched(bgCtx, account.ID, state); err != nil { + slog.Warn("antigravity_token_provider.temp_unsched_cache_set_failed", + "account_id", account.ID, + "error", err, + ) + } + } +} + func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) { p.backfillCooldown.Store(accountID, time.Now()) } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 18e9ff7a..48e0ab2f 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -63,6 +63,8 @@ type APIKeyRepository interface { ListByGroupID(ctx 
context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) + // UpdateGroupIDByUserAndGroup 将用户下绑定 oldGroupID 的所有 Key 迁移到 newGroupID + UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 97b8e229..357f8def 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -80,6 +80,9 @@ func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { panic("unexpected ClearGroupIDByGroupID call") } +func (s *authRepoStub) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + panic("unexpected UpdateGroupIDByUserAndGroup call") +} func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { panic("unexpected CountByGroupID call") diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index dfd481e8..392d52b9 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -108,6 +108,9 @@ func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keywor func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { 
panic("unexpected ClearGroupIDByGroupID call") } +func (s *apiKeyRepoStub) UpdateGroupIDByUserAndGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { + panic("unexpected UpdateGroupIDByUserAndGroup call") +} func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { panic("unexpected CountByGroupID call") diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go index 2e2f6f78..cf05e16c 100644 --- a/backend/internal/service/api_key_service_quota_test.go +++ b/backend/internal/service/api_key_service_quota_test.go @@ -122,6 +122,9 @@ func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { panic("unexpected ClearGroupIDByGroupID call") } +func (s *quotaBaseAPIKeyRepoStub) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { + panic("unexpected UpdateGroupIDByUserAndGroup call") +} func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { panic("unexpected CountByGroupID call") } diff --git a/backend/internal/service/backup_service.go b/backend/internal/service/backup_service.go index 25f1e9a1..2fcf2da8 100644 --- a/backend/internal/service/backup_service.go +++ b/backend/internal/service/backup_service.go @@ -4,11 +4,13 @@ import ( "compress/gzip" "context" "encoding/json" + "errors" "fmt" "io" "sort" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -84,17 +86,21 @@ type BackupScheduleConfig struct { // BackupRecord 备份记录 type BackupRecord struct { - ID string `json:"id"` - Status string `json:"status"` // pending, running, completed, failed - BackupType string `json:"backup_type"` // postgres - FileName string `json:"file_name"` - S3Key string `json:"s3_key"` - SizeBytes int64 `json:"size_bytes"` - TriggeredBy string `json:"triggered_by"` // 
manual, scheduled - ErrorMsg string `json:"error_message,omitempty"` - StartedAt string `json:"started_at"` - FinishedAt string `json:"finished_at,omitempty"` - ExpiresAt string `json:"expires_at,omitempty"` // 过期时间 + ID string `json:"id"` + Status string `json:"status"` // pending, running, completed, failed + BackupType string `json:"backup_type"` // postgres + FileName string `json:"file_name"` + S3Key string `json:"s3_key"` + SizeBytes int64 `json:"size_bytes"` + TriggeredBy string `json:"triggered_by"` // manual, scheduled + ErrorMsg string `json:"error_message,omitempty"` + StartedAt string `json:"started_at"` + FinishedAt string `json:"finished_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` // 过期时间 + Progress string `json:"progress,omitempty"` // "dumping", "uploading", "" + RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed" + RestoreError string `json:"restore_error,omitempty"` + RestoredAt string `json:"restored_at,omitempty"` } // BackupService 数据库备份恢复服务 @@ -105,17 +111,24 @@ type BackupService struct { storeFactory BackupObjectStoreFactory dumper DBDumper - mu sync.Mutex - store BackupObjectStore - s3Cfg *BackupS3Config + opMu sync.Mutex // 保护 backingUp/restoring 标志 backingUp bool restoring bool + storeMu sync.Mutex // 保护 store/s3Cfg 缓存 + store BackupObjectStore + s3Cfg *BackupS3Config + recordsMu sync.Mutex // 保护 records 的 load/save 操作 cronMu sync.Mutex cronSched *cron.Cron cronEntryID cron.EntryID + + wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine + shuttingDown atomic.Bool // 阻止新备份启动 + bgCtx context.Context // 所有后台操作的 parent context + bgCancel context.CancelFunc // 取消所有活跃后台操作 } func NewBackupService( @@ -125,20 +138,26 @@ func NewBackupService( storeFactory BackupObjectStoreFactory, dumper DBDumper, ) *BackupService { + bgCtx, bgCancel := context.WithCancel(context.Background()) return &BackupService{ settingRepo: settingRepo, dbCfg: &cfg.Database, encryptor: encryptor, storeFactory: 
storeFactory, dumper: dumper, + bgCtx: bgCtx, + bgCancel: bgCancel, } } -// Start 启动定时备份调度器 +// Start 启动定时备份调度器并清理孤立记录 func (s *BackupService) Start() { s.cronSched = cron.New() s.cronSched.Start() + // 清理重启后孤立的 running 记录 + s.recoverStaleRecords() + // 加载已有的定时配置 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -154,13 +173,65 @@ func (s *BackupService) Start() { } } -// Stop 停止定时备份 +// recoverStaleRecords 启动时将孤立的 running 记录标记为 failed +func (s *BackupService) recoverStaleRecords() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + records, err := s.loadRecords(ctx) + if err != nil { + return + } + for i := range records { + if records[i].Status == "running" { + records[i].Status = "failed" + records[i].ErrorMsg = "interrupted by server restart" + records[i].Progress = "" + records[i].FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(ctx, &records[i]) + logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", records[i].ID) + } + if records[i].RestoreStatus == "running" { + records[i].RestoreStatus = "failed" + records[i].RestoreError = "interrupted by server restart" + _ = s.saveRecord(ctx, &records[i]) + logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", records[i].ID) + } + } +} + +// Stop 停止定时备份并等待活跃操作完成 func (s *BackupService) Stop() { + s.shuttingDown.Store(true) + s.cronMu.Lock() - defer s.cronMu.Unlock() if s.cronSched != nil { s.cronSched.Stop() } + s.cronMu.Unlock() + + // 等待活跃备份/恢复完成(最多 5 分钟) + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + select { + case <-done: + logger.LegacyPrintf("service.backup", "[Backup] all active operations finished") + case <-time.After(5 * time.Minute): + logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations") + if s.bgCancel != nil { + s.bgCancel() // 取消所有后台操作 + } + // 给 
goroutine 时间响应取消并完成清理 + select { + case <-done: + logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up") + case <-time.After(10 * time.Second): + logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out") + } + } } // ─── S3 配置管理 ─── @@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) } // 清除缓存的 S3 客户端 - s.mu.Lock() + s.storeMu.Lock() s.store = nil s.s3Cfg = nil - s.mu.Unlock() + s.storeMu.Unlock() cfg.SecretAccessKey = "" return &cfg, nil @@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() { } func (s *BackupService) runScheduledBackup() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + s.wg.Add(1) + defer s.wg.Done() + + ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute) defer cancel() // 读取定时备份配置中的过期天数 @@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() { logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays) record, err := s.CreateBackup(ctx, "scheduled", expireDays) if err != nil { - logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err) + if errors.Is(err, ErrBackupInProgress) { + logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中") + } else { + logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err) + } return } logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes) @@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() { // CreateBackup 创建全量数据库备份并上传到 S3(流式处理) // expireDays: 备份过期天数,0=永不过期,默认14天 func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) { - s.mu.Lock() + if s.shuttingDown.Load() { + return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down") + } + + s.opMu.Lock() if s.backingUp { - s.mu.Unlock() + s.opMu.Unlock() return nil, ErrBackupInProgress } s.backingUp = true 
- s.mu.Unlock() + s.opMu.Unlock() defer func() { - s.mu.Lock() + s.opMu.Lock() s.backingUp = false - s.mu.Unlock() + s.opMu.Unlock() }() s3Cfg, err := s.loadS3Config(ctx) @@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex // 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传 pr, pw := io.Pipe() - var gzipErr error + gzipDone := make(chan error, 1) go func() { + defer func() { + if r := recover(); r != nil { + pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck + gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r) + } + }() gzWriter := gzip.NewWriter(pw) - _, gzipErr = io.Copy(gzWriter, dumpReader) - if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil { - gzipErr = closeErr + var gzErr error + _, gzErr = io.Copy(gzWriter, dumpReader) + if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil { + gzErr = closeErr } - if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil { - gzipErr = closeErr + if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil { + gzErr = closeErr } - if gzipErr != nil { - _ = pw.CloseWithError(gzipErr) + if gzErr != nil { + _ = pw.CloseWithError(gzErr) } else { _ = pw.Close() } + gzipDone <- gzErr }() contentType := "application/gzip" sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType) if err != nil { + _ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂 + gzErr := <-gzipDone // 安全等待 gzip goroutine 完成 record.Status = "failed" errMsg := fmt.Sprintf("S3 upload failed: %v", err) - if gzipErr != nil { - errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr) + if gzErr != nil { + errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr) } record.ErrorMsg = errMsg record.FinishedAt = time.Now().Format(time.RFC3339) _ = s.saveRecord(ctx, record) return record, fmt.Errorf("backup upload: %w", err) } + <-gzipDone // 确保 gzip goroutine 已退出 record.SizeBytes = sizeBytes record.Status = "completed" @@ -446,19 +539,187 @@ func (s 
*BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex return record, nil } +// StartBackup 异步创建备份,立即返回 running 状态的记录 +func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) { + if s.shuttingDown.Load() { + return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down") + } + + s.opMu.Lock() + if s.backingUp { + s.opMu.Unlock() + return nil, ErrBackupInProgress + } + s.backingUp = true + s.opMu.Unlock() + + // 初始化阶段出错时自动重置标志 + launched := false + defer func() { + if !launched { + s.opMu.Lock() + s.backingUp = false + s.opMu.Unlock() + } + }() + + // 在返回前加载 S3 配置和创建 store,避免 goroutine 中配置被修改 + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err + } + if s3Cfg == nil || !s3Cfg.IsConfigured() { + return nil, ErrBackupS3NotConfigured + } + + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return nil, fmt.Errorf("init object store: %w", err) + } + + now := time.Now() + backupID := uuid.New().String()[:8] + fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405")) + s3Key := s.buildS3Key(s3Cfg, fileName) + + var expiresAt string + if expireDays > 0 { + expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339) + } + + record := &BackupRecord{ + ID: backupID, + Status: "running", + BackupType: "postgres", + FileName: fileName, + S3Key: s3Key, + TriggeredBy: triggeredBy, + StartedAt: now.Format(time.RFC3339), + ExpiresAt: expiresAt, + Progress: "pending", + } + + if err := s.saveRecord(ctx, record); err != nil { + return nil, fmt.Errorf("save initial record: %w", err) + } + + launched = true + // 在启动 goroutine 前完成拷贝,避免数据竞争 + result := *record + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer func() { + s.opMu.Lock() + s.backingUp = false + s.opMu.Unlock() + }() + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r) + 
record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("internal panic: %v", r) + record.Progress = "" + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(context.Background(), record) + } + }() + s.executeBackup(record, objectStore) + }() + + return &result, nil +} + +// executeBackup 后台执行备份(独立于 HTTP context) +func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) { + ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute) + defer cancel() + + // 阶段1: pg_dump + record.Progress = "dumping" + _ = s.saveRecord(ctx, record) + + dumpReader, err := s.dumper.Dump(ctx) + if err != nil { + record.Status = "failed" + record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err) + record.Progress = "" + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(context.Background(), record) + return + } + + // 阶段2: gzip + upload + record.Progress = "uploading" + _ = s.saveRecord(ctx, record) + + pr, pw := io.Pipe() + gzipDone := make(chan error, 1) + go func() { + defer func() { + if r := recover(); r != nil { + pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck + gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r) + } + }() + gzWriter := gzip.NewWriter(pw) + var gzErr error + _, gzErr = io.Copy(gzWriter, dumpReader) + if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil { + gzErr = closeErr + } + if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil { + gzErr = closeErr + } + if gzErr != nil { + _ = pw.CloseWithError(gzErr) + } else { + _ = pw.Close() + } + gzipDone <- gzErr + }() + + contentType := "application/gzip" + sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType) + if err != nil { + _ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂 + gzErr := <-gzipDone // 安全等待 gzip goroutine 完成 + record.Status = "failed" + errMsg := fmt.Sprintf("S3 upload failed: %v", err) + if gzErr != nil { + errMsg = fmt.Sprintf("gzip/dump failed: 
%v", gzErr) + } + record.ErrorMsg = errMsg + record.Progress = "" + record.FinishedAt = time.Now().Format(time.RFC3339) + _ = s.saveRecord(context.Background(), record) + return + } + <-gzipDone // 确保 gzip goroutine 已退出 + + record.SizeBytes = sizeBytes + record.Status = "completed" + record.Progress = "" + record.FinishedAt = time.Now().Format(time.RFC3339) + if err := s.saveRecord(context.Background(), record); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err) + } +} + // RestoreBackup 从 S3 下载备份并流式恢复到数据库 func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error { - s.mu.Lock() + s.opMu.Lock() if s.restoring { - s.mu.Unlock() + s.opMu.Unlock() return ErrRestoreInProgress } s.restoring = true - s.mu.Unlock() + s.opMu.Unlock() defer func() { - s.mu.Lock() + s.opMu.Lock() s.restoring = false - s.mu.Unlock() + s.opMu.Unlock() }() record, err := s.GetBackupRecord(ctx, backupID) @@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro return nil } +// StartRestore 异步恢复备份,立即返回 +func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) { + if s.shuttingDown.Load() { + return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down") + } + + s.opMu.Lock() + if s.restoring { + s.opMu.Unlock() + return nil, ErrRestoreInProgress + } + s.restoring = true + s.opMu.Unlock() + + // 初始化阶段出错时自动重置标志 + launched := false + defer func() { + if !launched { + s.opMu.Lock() + s.restoring = false + s.opMu.Unlock() + } + }() + + record, err := s.GetBackupRecord(ctx, backupID) + if err != nil { + return nil, err + } + if record.Status != "completed" { + return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup") + } + + s3Cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err + } + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) + if err != nil { + return nil, 
fmt.Errorf("init object store: %w", err) + } + + record.RestoreStatus = "running" + _ = s.saveRecord(ctx, record) + + launched = true + result := *record + + s.wg.Add(1) + go func() { + defer s.wg.Done() + defer func() { + s.opMu.Lock() + s.restoring = false + s.opMu.Unlock() + }() + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r) + record.RestoreStatus = "failed" + record.RestoreError = fmt.Sprintf("internal panic: %v", r) + _ = s.saveRecord(context.Background(), record) + } + }() + s.executeRestore(record, objectStore) + }() + + return &result, nil +} + +// executeRestore 后台执行恢复 +func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) { + ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute) + defer cancel() + + body, err := objectStore.Download(ctx, record.S3Key) + if err != nil { + record.RestoreStatus = "failed" + record.RestoreError = fmt.Sprintf("S3 download failed: %v", err) + _ = s.saveRecord(context.Background(), record) + return + } + defer func() { _ = body.Close() }() + + gzReader, err := gzip.NewReader(body) + if err != nil { + record.RestoreStatus = "failed" + record.RestoreError = fmt.Sprintf("gzip reader: %v", err) + _ = s.saveRecord(context.Background(), record) + return + } + defer func() { _ = gzReader.Close() }() + + if err := s.dumper.Restore(ctx, gzReader); err != nil { + record.RestoreStatus = "failed" + record.RestoreError = fmt.Sprintf("pg restore: %v", err) + _ = s.saveRecord(context.Background(), record) + return + } + + record.RestoreStatus = "completed" + record.RestoredAt = time.Now().Format(time.RFC3339) + if err := s.saveRecord(context.Background(), record); err != nil { + logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err) + } +} + // ─── 备份记录管理 ─── func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) { @@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx 
context.Context) (*BackupS3Config, erro } func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.storeMu.Lock() + defer s.storeMu.Unlock() if s.store != nil && s.s3Cfg != nil { return s.store, nil diff --git a/backend/internal/service/backup_service_test.go b/backend/internal/service/backup_service_test.go index e752997c..b308e6d0 100644 --- a/backend/internal/service/backup_service_test.go +++ b/backend/internal/service/backup_service_test.go @@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error { return nil } +// blockingDumper 可控延迟的 dumper,用于测试异步行为 +type blockingDumper struct { + blockCh chan struct{} + data []byte + restErr error +} + +func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) { + select { + case <-d.blockCh: + case <-ctx.Done(): + return nil, ctx.Err() + } + return io.NopCloser(bytes.NewReader(d.data)), nil +} + +func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error { + if d.restErr != nil { + return d.restErr + } + _, _ = io.ReadAll(data) + return nil +} + type mockObjectStore struct { objects map[string][]byte mu sync.Mutex @@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error { return nil } -func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService { +func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService { cfg := &config.Config{ Database: config.DatabaseConfig{ Host: "localhost", @@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) { svc := newTestBackupService(repo, dumper, store) // 手动设置 backingUp 标志 - svc.mu.Lock() + svc.opMu.Lock() svc.backingUp = true - svc.mu.Unlock() + svc.opMu.Unlock() _, err := svc.CreateBackup(context.Background(), "manual", 14) require.ErrorIs(t, err, ErrBackupInProgress) @@ -526,3 
+550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) { require.Error(t, err) require.Nil(t, cfg) } + +// ─── Async Backup Tests ─── + +func TestStartBackup_ReturnsImmediately(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.StartBackup(context.Background(), "manual", 14) + require.NoError(t, err) + require.Equal(t, "running", record.Status) + require.NotEmpty(t, record.ID) + + // 释放 dumper 让后台完成 + close(dumper.blockCh) + svc.wg.Wait() + + // 验证最终状态 + final, err := svc.GetBackupRecord(context.Background(), record.ID) + require.NoError(t, err) + require.Equal(t, "completed", final.Status) + require.Greater(t, final.SizeBytes, int64(0)) +} + +func TestStartBackup_ConcurrentBlocked(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 第一次启动 + _, err := svc.StartBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // 第二次应被阻塞 + _, err = svc.StartBackup(context.Background(), "manual", 14) + require.ErrorIs(t, err, ErrBackupInProgress) + + close(dumper.blockCh) + svc.wg.Wait() +} + +func TestStartBackup_ShuttingDown(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + svc := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore()) + + svc.shuttingDown.Store(true) + + _, err := svc.StartBackup(context.Background(), "manual", 14) + require.Error(t, err) + require.Contains(t, err.Error(), "shutting down") +} + +func TestRecoverStaleRecords(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 模拟一条孤立的 running 记录 + _ = 
svc.saveRecord(context.Background(), &BackupRecord{ + ID: "stale-1", + Status: "running", + StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }) + // 模拟一条孤立的恢复中记录 + _ = svc.saveRecord(context.Background(), &BackupRecord{ + ID: "stale-2", + Status: "completed", + RestoreStatus: "running", + StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }) + + svc.recoverStaleRecords() + + r1, _ := svc.GetBackupRecord(context.Background(), "stale-1") + require.Equal(t, "failed", r1.Status) + require.Contains(t, r1.ErrorMsg, "server restart") + + r2, _ := svc.GetBackupRecord(context.Background(), "stale-2") + require.Equal(t, "failed", r2.RestoreStatus) + require.Contains(t, r2.RestoreError, "server restart") +} + +func TestGracefulShutdown(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + _, err := svc.StartBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // Stop 应该等待备份完成 + done := make(chan struct{}) + go func() { + svc.Stop() + close(done) + }() + + // 短暂等待确认 Stop 还在等待 + select { + case <-done: + t.Fatal("Stop returned before backup finished") + case <-time.After(100 * time.Millisecond): + // 预期:Stop 还在等待 + } + + // 释放备份 + close(dumper.blockCh) + + // 现在 Stop 应该完成 + select { + case <-done: + // 预期 + case <-time.After(5 * time.Second): + t.Fatal("Stop did not return after backup finished") + } +} + +func TestStartRestore_Async(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 先创建一个备份(同步方式) + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // 异步恢复 + 
restored, err := svc.StartRestore(context.Background(), record.ID) + require.NoError(t, err) + require.Equal(t, "running", restored.RestoreStatus) + + svc.wg.Wait() + + // 验证最终状态 + final, err := svc.GetBackupRecord(context.Background(), record.ID) + require.NoError(t, err) + require.Equal(t, "completed", final.RestoreStatus) +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 68d7a8f9..004511f5 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -221,6 +221,18 @@ func (s *BillingService) initFallbackPricing() { LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, } + s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ + InputPricePerToken: 7.5e-7, + OutputPricePerToken: 4.5e-6, + CacheReadPricePerToken: 7.5e-8, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ + InputPricePerToken: 2e-7, + OutputPricePerToken: 1.25e-6, + CacheReadPricePerToken: 2e-8, + SupportsCacheBreakdown: false, + } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -294,6 +306,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { normalized := normalizeCodexModel(modelLower) switch normalized { + case "gpt-5.4-mini": + return s.fallbackPrices["gpt-5.4-mini"] + case "gpt-5.4-nano": + return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 45bbdcee..10943422 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -174,6 +174,30 @@ func 
TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) } +func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4-mini") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 7.5e-7, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 4.5e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 7.5e-8, pricing.CacheReadPricePerToken, 1e-12) + require.Zero(t, pricing.LongContextInputThreshold) +} + +func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4-nano") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12) + require.Zero(t, pricing.LongContextInputThreshold) +} + func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { svc := newTestBillingService() @@ -210,6 +234,8 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, + {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, + {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, @@ -564,6 +590,40 @@ func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) { 
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) } +func TestCalculateCostWithServiceTier_Gpt54MiniPriorityFallsBackToTierMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} + + baseCost, err := svc.CalculateCost("gpt-5.4-mini", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.4-mini", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_Gpt54NanoFlexAppliesHalfMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.4-nano", tokens, 1.0) + require.NoError(t, err) + + flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4-nano", tokens, 1.0, "flex") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) +} + func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) { svc := newTestBillingService() tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, 
CacheReadTokens: 8} diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index f71098b1..4e8ced67 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -21,9 +21,6 @@ var ( // 带捕获组的版本提取正则 claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) - // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} - userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) - // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) systemPromptThreshold = 0.5 ) @@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return false } - if !userIDPattern.MatchString(userID) { + if ParseMetadataUserID(userID) == nil { return false } @@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { - matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) - if len(matches) >= 2 { - return matches[1] - } - return "" + return ExtractCLIVersion(ua) } // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 6a916740..b69b0639 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after creation if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } } item.Action = 
"created" @@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } } @@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } // 🔄 Refresh OAuth token after creation if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } item.Action = "updated" @@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } 
item.Action = "updated" diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 63cad243..3e059e30 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) { + normalizedSource := usagestats.NormalizeModelSource(modelSource) + if normalizedSource == usagestats.ModelSourceRequested { + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + } + + type modelStatsBySourceRepo interface { + GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error) + } + + if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok { + stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource) + if err != nil { + return nil, fmt.Errorf("get model stats with filters by source: %w", err) + } + return stats, nil + } + + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) +} + func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, 
endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { @@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi return stats, nil } +// GetGroupUsageSummary returns today's and cumulative cost for all groups. +func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart) + if err != nil { + return nil, fmt.Errorf("get group usage summary: %w", err) + } + return results, nil +} + func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil { @@ -335,6 +365,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime return ranking, nil } +func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { + stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit) + if err != nil { + return nil, fmt.Errorf("get user breakdown stats: %w", err) + } + return stats, nil +} + func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2d8681d4..ecac0db0 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -119,6 +119,7 @@ const ( SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // 
"购买订阅"页面 URL(作为 iframe src) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) + SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 @@ -170,6 +171,13 @@ const ( // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= + // Overload Cooldown (529) + // ========================= + + // SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling. + SettingKeyOverloadCooldownSettings = "overload_cooldown_settings" + // ========================= // Stream Timeout Handling // ========================= @@ -219,11 +227,20 @@ const ( // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) SettingKeyMinClaudeCodeVersion = "min_claude_code_version" + // SettingKeyMaxClaudeCodeVersion 最高 Claude Code 版本号限制 (semver, 如 "3.0.0",空值=不检查) + SettingKeyMaxClaudeCodeVersion = "max_claude_code_version" + // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403) SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" // SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录 SettingKeyBackendModeEnabled = "backend_mode_enabled" + + // Gateway Forwarding Behavior + // SettingKeyEnableFingerprintUnification 是否统一 OAuth 账号的 X-Stainless-* 指纹头(默认 true) + SettingKeyEnableFingerprintUnification = "enable_fingerprint_unification" + // SettingKeyEnableMetadataPassthrough 是否透传客户端原始 metadata.user_id(默认 false) + SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 44edf7f7..00691233 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -12,6 +12,7 @@ import ( "net/smtp" "net/url" "strconv" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { return nil, fmt.Errorf("get smtp settings: %w", err) } - host := settings[SettingKeySMTPHost] + host := strings.TrimSpace(settings[SettingKeySMTPHost]) if host == "" { return nil, ErrEmailNotConfigured } @@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { return &SMTPConfig{ Host: host, Port: port, - Username: settings[SettingKeySMTPUsername], - Password: settings[SettingKeySMTPPassword], - From: settings[SettingKeySMTPFrom], - FromName: settings[SettingKeySMTPFromName], + Username: strings.TrimSpace(settings[SettingKeySMTPUsername]), + Password: strings.TrimSpace(settings[SettingKeySMTPPassword]), + From: strings.TrimSpace(settings[SettingKeySMTPFrom]), + FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]), UseTLS: useTLS, }, nil } diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go index a8b42a2c..aa3e6ec4 100644 --- a/backend/internal/service/error_policy_integration_test.go +++ b/backend/internal/service/error_policy_integration_test.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/stretchr/testify/require" ) @@ -35,7 +36,7 @@ func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64 }, nil } -func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, 
accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return u.Do(req, proxyURL, accountID, accountConcurrency) } diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 789cbab8..6e19db32 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -60,7 +61,7 @@ func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, a return u.resp, nil } -func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return u.Do(req, proxyURL, accountID, accountConcurrency) } @@ -175,13 +176,13 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射") - require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) - require.Empty(t, upstream.lastReq.Header.Get("authorization")) - require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key")) - require.Empty(t, upstream.lastReq.Header.Get("cookie")) - require.Equal(t, "2023-06-01", 
upstream.lastReq.Header.Get("anthropic-version")) - require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta")) - require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头") + require.Equal(t, "upstream-anthropic-key", getHeaderRaw(upstream.lastReq.Header, "x-api-key")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "authorization")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "x-goog-api-key")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "cookie")) + require.Equal(t, "2023-06-01", getHeaderRaw(upstream.lastReq.Header, "anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(upstream.lastReq.Header, "anthropic-beta")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头") require.Contains(t, rec.Body.String(), `"cached_tokens":7`) require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写") @@ -257,9 +258,9 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo require.NoError(t, err) require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射") - require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) - require.Empty(t, upstream.lastReq.Header.Get("authorization")) - require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, "upstream-anthropic-key", getHeaderRaw(upstream.lastReq.Header, "x-api-key")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "authorization")) + require.Empty(t, getHeaderRaw(upstream.lastReq.Header, "cookie")) require.Equal(t, http.StatusOK, rec.Code) require.JSONEq(t, upstreamRespBody, rec.Body.String()) require.Empty(t, rec.Header().Get("Set-Cookie")) @@ -684,8 +685,85 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t req, err := 
svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false) require.NoError(t, err) - require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization")) - require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") + require.Equal(t, "Bearer oauth-token", getHeaderRaw(req.Header, "authorization")) + require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") +} + +func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + body string + }{ + { + name: "system array", + body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`, + }, + { + name: "system string", + body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic) + require.NoError(t, err) + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-oauth-preserve"}, + }, + Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)), + }, + } + + cfg := &config.Config{ + Gateway: 
config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + } + + account := &Account{ + ID: 301, + Name: "anthropic-oauth-preserve", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "Bearer oauth-token", getHeaderRaw(upstream.lastReq.Header, "authorization")) + require.Contains(t, getHeaderRaw(upstream.lastReq.Header, "anthropic-beta"), claude.BetaOAuth) + + system := gjson.GetBytes(upstream.lastBody, "system") + require.True(t, system.Exists()) + require.Contains(t, system.Raw, "x-anthropic-billing-header keep") + }) + } } func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { @@ -788,7 +866,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc rateLimitService: &RateLimitService{}, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 12, result.Usage.InputTokens) @@ -815,7 +893,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp } svc := &GatewayService{} - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, 
account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "requires apikey token") @@ -840,7 +918,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest } account := newAnthropicAPIKeyAccountForTest() - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "upstream request failed") @@ -873,7 +951,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo httpUpstream: upstream, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "empty response") diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go new file mode 100644 index 00000000..641522f0 --- /dev/null +++ b/backend/internal/service/gateway_body_order_test.go @@ -0,0 +1,72 @@ +package service + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/stretchr/testify/require" +) + +func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) { + t.Helper() + + last := -1 + for _, token := range tokens { + pos := 
strings.Index(body, token) + require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body) + require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body) + last = pos + } +} + +func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`) + + result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022") + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`) + require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`) +} + +func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`) + + result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{ + injectMetadata: true, + metadataUserID: "user-1", + }) + resultStr := string(result) + + require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID) + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`) + require.NotContains(t, resultStr, `"temperature"`) + require.NotContains(t, resultStr, `"tool_choice"`) + require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`) + require.Contains(t, resultStr, `"tools":[]`) + require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`) +} + +func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`) + + result := injectClaudeCodePrompt(body, []any{ + map[string]any{"id": "block-1", "type": "text", "text": "Custom"}, + }) + 
resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) + require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`) +} + +func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) { + body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`) + + result := enforceCacheControlLimit(body) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) + require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`)) +} diff --git a/backend/internal/service/gateway_debug_env_test.go b/backend/internal/service/gateway_debug_env_test.go new file mode 100644 index 00000000..bd88a667 --- /dev/null +++ b/backend/internal/service/gateway_debug_env_test.go @@ -0,0 +1,31 @@ +package service + +import "testing" + +func TestParseDebugEnvBool(t *testing.T) { + t.Run("empty is false", func(t *testing.T) { + if parseDebugEnvBool("") { + t.Fatalf("expected false for empty string") + } + }) + + t.Run("true-like values", func(t *testing.T) { + for _, value := range []string{"1", "true", "TRUE", "yes", "on"} { + t.Run(value, func(t *testing.T) { + if !parseDebugEnvBool(value) { + t.Fatalf("expected true for %q", value) + } + }) + } + }) + + t.Run("false-like values", func(t *testing.T) { + for _, value := range []string{"0", "false", "off", "debug"} { + t.Run(value, func(t *testing.T) { + if parseDebugEnvBool(value) { + t.Fatalf("expected false for %q", value) + } + }) + } + }) +} diff --git 
a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go new file mode 100644 index 00000000..37b38f76 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions.go
package service

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"go.uber.org/zap"
)

// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body,
// converts it to Anthropic Messages format (chained via Responses format),
// forwards to the Anthropic upstream, and converts the response back to Chat
// Completions format. This enables Chat Completions clients to access Anthropic
// models through Anthropic platform groups.
//
// On success it returns a ForwardResult with usage/model/latency accounting and
// writes the client response itself (JSON or SSE, depending on ccReq.Stream).
// On upstream failure it may return *UpstreamFailoverError so the caller can
// retry on another account.
//
// NOTE(review): `parsed` is accepted for interface symmetry with the other
// forwarders but is not referenced in this function — confirm that is intended.
func (s *GatewayService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()

	// 1. Parse Chat Completions request
	var ccReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &ccReq); err != nil {
		return nil, fmt.Errorf("parse chat completions request: %w", err)
	}
	originalModel := ccReq.Model
	clientStream := ccReq.Stream
	// Honor stream_options.include_usage so the final SSE chunk can carry usage.
	includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage

	// 2. Convert CC → Responses → Anthropic (chained conversion)
	responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
	if err != nil {
		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
	}

	anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
	if err != nil {
		return nil, fmt.Errorf("convert responses to anthropic: %w", err)
	}

	// 3. Force upstream streaming. The upstream leg is always SSE; the client's
	// stream preference only decides whether we re-stream or buffer the reply.
	anthropicReq.Stream = true
	reqStream := true

	// 4. Model mapping: API-key accounts use their configured mapping table;
	// other Anthropic accounts only normalize the model ID.
	mappedModel := originalModel
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(originalModel)
	}
	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
		normalized := claude.NormalizeModelID(originalModel)
		if normalized != originalModel {
			mappedModel = normalized
		}
	}
	anthropicReq.Model = mappedModel

	logger.L().Debug("gateway forward_as_chat_completions: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("mapped_model", mappedModel),
		zap.Bool("client_stream", clientStream),
	)

	// 5. Marshal Anthropic request body
	anthropicBody, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, fmt.Errorf("marshal anthropic request: %w", err)
	}

	// 6. Apply Claude Code mimicry for OAuth accounts.
	// NOTE(review): "haiku" models skip prompt injection — presumably matching
	// upstream expectations for those models; confirm.
	isClaudeCode := false // CC API is never Claude Code
	shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode

	if shouldMimicClaudeCode {
		if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
			!systemIncludesClaudeCodePrompt(anthropicReq.System) {
			anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
		}
	}

	// 7. Enforce cache_control block limit
	anthropicBody = enforceCacheControlLimit(anthropicBody)

	// 8. Get access token
	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}

	// 9. Get proxy URL
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	// 10. Build upstream request (context detached for streaming lifetimes)
	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}

	// 11. Send request. On transport failure the (possibly non-nil) body is
	// closed defensively and the error is sanitized before surfacing.
	resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
	if err != nil {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 12. Handle error response with failover. The body is drained (capped at
	// 2 MiB) and swapped for a NopCloser so the deferred Close stays harmless.
	if resp.StatusCode >= 400 {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)

		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:   resp.StatusCode,
				ResponseBody: respBody,
			}
		}

		writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
		return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
	}

	// 13. Extract reasoning effort from CC request body (for usage logging)
	reasoningEffort := extractCCReasoningEffortFromBody(body)

	// 14. Handle normal response
	// Read Anthropic SSE → convert to Responses events → convert to CC format
	var result *ForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage)
	} else {
		result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	}

	return result, handleErr
}

// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions
// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort)
// formats used by OpenAI-compatible clients.
+func extractCCReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full +// response, then converts Anthropic → Responses → Chat Completions. +func (s *GatewayService) handleCCBufferedFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResp *apicompat.AnthropicResponse + var usage ClaudeUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + // message_start carries the initial response structure and cache usage + if event.Type == "message_start" && event.Message != nil { + finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // message_delta carries final usage and stop_reason + if event.Type == "message_delta" { + if event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { + finalResp.StopReason = 
event.Delta.StopReason + } + } + if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil { + finalResp.Content = append(finalResp.Content, *event.ContentBlock) + } + if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil { + idx := *event.Index + if idx < len(finalResp.Content) { + switch event.Delta.Type { + case "text_delta": + finalResp.Content[idx].Text += event.Delta.Text + case "thinking_delta": + finalResp.Content[idx].Thinking += event.Delta.Thinking + case "input_json_delta": + finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON) + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResp == nil { + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response") + return nil, fmt.Errorf("upstream stream ended without response") + } + + // Update usage from accumulated delta + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + finalResp.Usage = apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + // Chain: Anthropic → Responses → Chat Completions + responsesResp := apicompat.AnthropicToResponsesResponse(finalResp) + ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, ccResp) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: 
reasoningEffort, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each +// to Responses events, then to Chat Completions chunks, and writes them. +func (s *GatewayService) handleCCStreamingFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, + includeUsage bool, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + // Use Anthropic→Responses state machine, then convert Responses→CC + anthState := apicompat.NewAnthropicEventToResponsesState() + anthState.Model = originalModel + ccState := apicompat.NewResponsesEventToChatState() + ccState.Model = originalModel + ccState.IncludeUsage = includeUsage + + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *ForwardResult { + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + return false + } + if _, err := fmt.Fprint(c.Writer, sse); 
err != nil { + return true // client disconnected + } + return false + } + + processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // Extract usage from message_delta + if event.Type == "message_delta" && event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + // Also capture usage from message_start (carries cache fields) + if event.Type == "message_start" && event.Message != nil { + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // Chain: Anthropic event → Responses events → CC chunks + responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState) + for _, resEvt := range responsesEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + if disconnected := writeChunk(chunk); disconnected { + return true + } + } + } + c.Writer.Flush() + return false + } + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + if processAnthropicEvent(&event) { + return resultWithUsage(), nil + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Finalize both state machines + finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState) + for _, resEvt := range finalResEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + writeChunk(chunk) 
//nolint:errcheck + } + } + finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState) + for _, chunk := range finalCCChunks { + writeChunk(chunk) //nolint:errcheck + } + + // Write [DONE] marker + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + + return resultWithUsage(), nil +} + +// writeGatewayCCError writes an error in OpenAI Chat Completions format for +// the Anthropic-upstream CC forwarding path. +func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/gateway_forward_as_chat_completions_test.go b/backend/internal/service/gateway_forward_as_chat_completions_test.go new file mode 100644 index 00000000..5003e5b3 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractCCReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + t.Run("nested reasoning.effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + }) + + t.Run("flat reasoning_effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`)) + require.NotNil(t, got) + require.Equal(t, "xhigh", *got) + }) + + t.Run("missing effort", func(t *testing.T) { + require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`))) + }) +} + +func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := 
"high" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) +} + +func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := "medium" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: 
{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "medium", *result.ReasoningEffort) + require.Contains(t, rec.Body.String(), `[DONE]`) +} diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go new file mode 100644 index 00000000..2c917112 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -0,0 +1,518 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardAsResponses accepts an OpenAI Responses API request body, converts it +// to Anthropic Messages format, forwards to the Anthropic upstream, and converts +// the response back to Responses format. This enables OpenAI Responses API +// clients to access Anthropic models through Anthropic platform groups. 
//
// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic
// but in reverse direction: Responses → Anthropic upstream → Responses.
//
// It writes the client response itself (JSON or SSE depending on the request's
// stream flag) and returns accounting data in *ForwardResult; on retriable
// upstream errors it returns *UpstreamFailoverError instead.
//
// NOTE(review): `parsed` is not referenced in this function — confirm it is
// kept only for forwarder-interface symmetry.
func (s *GatewayService) ForwardAsResponses(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()

	// 1. Parse Responses request
	var responsesReq apicompat.ResponsesRequest
	if err := json.Unmarshal(body, &responsesReq); err != nil {
		return nil, fmt.Errorf("parse responses request: %w", err)
	}
	originalModel := responsesReq.Model
	clientStream := responsesReq.Stream

	// 2. Convert Responses → Anthropic
	anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq)
	if err != nil {
		return nil, fmt.Errorf("convert responses to anthropic: %w", err)
	}

	// 3. Force upstream streaming (Anthropic works best with streaming).
	// The client's stream flag only decides whether we re-stream or buffer.
	anthropicReq.Stream = true
	reqStream := true

	// 4. Model mapping: API-key accounts use their mapping table; other
	// Anthropic accounts only normalize the model ID.
	mappedModel := originalModel
	reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(originalModel)
	}
	if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
		normalized := claude.NormalizeModelID(originalModel)
		if normalized != originalModel {
			mappedModel = normalized
		}
	}
	anthropicReq.Model = mappedModel

	logger.L().Debug("gateway forward_as_responses: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("mapped_model", mappedModel),
		zap.Bool("client_stream", clientStream),
	)

	// 5. Marshal Anthropic request body
	anthropicBody, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, fmt.Errorf("marshal anthropic request: %w", err)
	}

	// 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
	// NOTE(review): "haiku" models skip the prompt injection — confirm intended.
	isClaudeCode := false // Responses API is never Claude Code
	shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode

	if shouldMimicClaudeCode {
		if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
			!systemIncludesClaudeCodePrompt(anthropicReq.System) {
			anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
		}
	}

	// 7. Enforce cache_control block limit
	anthropicBody = enforceCacheControlLimit(anthropicBody)

	// 8. Get access token
	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}

	// 9. Get proxy URL
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	// 10. Build upstream request (context detached for streaming lifetimes)
	upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode)
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}

	// 11. Send request; transport failures are sanitized and reported as 502.
	resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
	if err != nil {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 12. Handle error response with failover. The body is drained (capped at
	// 2 MiB) and swapped for a NopCloser so the deferred Close stays harmless.
	if resp.StatusCode >= 400 {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)

		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:   resp.StatusCode,
				ResponseBody: respBody,
			}
		}

		// Non-failover error: return Responses-formatted error to client
		writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg)
		return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
	}

	// 13. Handle normal response (convert Anthropic → Responses)
	var result *ForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	} else {
		result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime)
	}

	return result, handleErr
}

// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort
// and normalizes it for usage logging.
+func ExtractResponsesReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) { + if dst == nil { + return + } + if src.InputTokens > 0 { + dst.InputTokens = src.InputTokens + } + if src.OutputTokens > 0 { + dst.OutputTokens = src.OutputTokens + } + if src.CacheReadInputTokens > 0 { + dst.CacheReadInputTokens = src.CacheReadInputTokens + } + if src.CacheCreationInputTokens > 0 { + dst.CacheCreationInputTokens = src.CacheCreationInputTokens + } +} + +// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from +// the upstream streaming response, assembles them into a complete Anthropic +// response, converts to Responses API JSON format, and writes it to the client. +func (s *GatewayService) handleResponsesBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + // Accumulate the final Anthropic response from streaming events + var finalResp *apicompat.AnthropicResponse + var usage ClaudeUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + eventType := strings.TrimPrefix(line, "event: ") + + // Read the data line + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event 
apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("forward_as_responses buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + zap.String("event_type", eventType), + ) + continue + } + + // message_start carries the initial response structure + if event.Type == "message_start" && event.Message != nil { + finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // message_delta carries final usage and stop_reason + if event.Type == "message_delta" { + if event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { + finalResp.StopReason = event.Delta.StopReason + } + } + + // Accumulate content blocks + if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil { + finalResp.Content = append(finalResp.Content, *event.ContentBlock) + } + if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil { + idx := *event.Index + if idx < len(finalResp.Content) { + switch event.Delta.Type { + case "text_delta": + finalResp.Content[idx].Text += event.Delta.Text + case "thinking_delta": + finalResp.Content[idx].Thinking += event.Delta.Thinking + case "input_json_delta": + finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON) + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_responses buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResp == nil { + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response") + return nil, fmt.Errorf("upstream stream ended without response") + } + + // Update usage from accumulated delta + if 
usage.InputTokens > 0 || usage.OutputTokens > 0 { + finalResp.Usage = apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + // Convert to Responses format + responsesResp := apicompat.AnthropicToResponsesResponse(finalResp) + responsesResp.Model = originalModel // Use original model name + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, responsesResp) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleResponsesStreamingResponse reads Anthropic SSE events from upstream, +// converts each to Responses SSE events, and writes them to the client. 
func (s *GatewayService) handleResponsesStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// SSE headers are written before any body bytes; X-Accel-Buffering
	// disables proxy buffering so chunks reach the client immediately.
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	state := apicompat.NewAnthropicEventToResponsesState()
	state.Model = originalModel
	var usage ClaudeUsage
	// NOTE(review): firstTokenMs is captured at the first parsed upstream
	// event, which may be message_start rather than a content token — confirm
	// this matches the metric's intent.
	var firstTokenMs *int
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// resultWithUsage snapshots the accounting data at return time.
	resultWithUsage := func() *ForwardResult {
		return &ForwardResult{
			RequestID:       requestID,
			Usage:           usage,
			Model:           originalModel,
			UpstreamModel:   mappedModel,
			ReasoningEffort: reasoningEffort,
			Stream:          true,
			Duration:        time.Since(startTime),
			FirstTokenMs:    firstTokenMs,
		}
	}

	// processEvent handles a single parsed Anthropic SSE event.
	// It returns true when the client has disconnected.
	processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}

		// Extract usage from message_delta
		if event.Type == "message_delta" && event.Usage != nil {
			mergeAnthropicUsage(&usage, *event.Usage)
		}
		// Also capture usage from message_start (carries cache fields)
		if event.Type == "message_start" && event.Message != nil {
			mergeAnthropicUsage(&usage, event.Message.Usage)
		}

		// Convert to Responses events and relay them; marshal failures skip
		// the single event, a write failure means the client went away.
		events := apicompat.AnthropicEventToResponsesEvents(event, state)
		for _, evt := range events {
			sse, err := apicompat.ResponsesEventToSSE(evt)
			if err != nil {
				logger.L().Warn("forward_as_responses stream: failed to marshal event",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				logger.L().Info("forward_as_responses stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true // client disconnected
			}
		}
		if len(events) > 0 {
			c.Writer.Flush()
		}
		return false
	}

	// finalizeStream flushes any terminal events the state machine still owes
	// (e.g. synthesized completion events) before returning the result.
	finalizeStream := func() (*ForwardResult, error) {
		if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
			for _, evt := range finalEvents {
				sse, err := apicompat.ResponsesEventToSSE(evt)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
			c.Writer.Flush()
		}
		return resultWithUsage(), nil
	}

	// Read Anthropic SSE events; each "event:" line must be immediately
	// followed by its "data:" line, other lines are skipped.
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		eventType := strings.TrimPrefix(line, "event: ")

		// Read data line
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		payload := dataLine[6:]

		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("forward_as_responses stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
				zap.String("event_type", eventType),
			)
			continue
		}

		if processEvent(&event) {
			return resultWithUsage(), nil
		}
	}

	// Context cancellation is an expected way for streams to end; only
	// unexpected read errors are logged.
	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_responses stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	return finalizeStream()
}

// appendRawJSON appends a JSON fragment string to existing raw JSON.
// Used to accumulate input_json_delta fragments into a tool-use input.
// NOTE(review): each call copies the whole accumulated buffer (string
// concatenation), i.e. O(n²) over many fragments — presumably acceptable for
// typical tool-input sizes.
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
	if len(existing) == 0 {
		return json.RawMessage(fragment)
	}
	return json.RawMessage(string(existing) + fragment)
}

// writeResponsesError writes an error response in OpenAI Responses API format.
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
	c.JSON(statusCode, gin.H{
		"error": gin.H{
			"code":    code,
			"message": message,
		},
	})
}

// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes.
+func mapUpstreamStatusCode(code int) int { + if code >= 500 { + return http.StatusBadGateway + } + return code +} diff --git a/backend/internal/service/gateway_forward_as_responses_test.go b/backend/internal/service/gateway_forward_as_responses_test.go new file mode 100644 index 00000000..e48d8b22 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses_test.go @@ -0,0 +1,94 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractResponsesReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + + require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`))) +} + +func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, 
"claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `"cached_tokens":9`) +} + +func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `response.completed`) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go 
index ea8fa784..2d16ad94 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil @@ -3139,7 +3139,7 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for 
require_privacy_set check require.Equal(t, 0, groupRepo.getByIDLiteCalls) } @@ -3182,7 +3182,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } @@ -3252,7 +3252,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index 52c75d1d..356536b0 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -124,6 +124,27 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { }, want: false, }, + // json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions) + { + name: "json.RawMessage string with Claude Code prompt", + system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`), + want: true, + }, + { + name: "json.RawMessage string without Claude Code prompt", + system: json.RawMessage(`"You are a helpful assistant"`), + want: false, + }, + { + name: "json.RawMessage nil (empty)", + system: json.RawMessage(nil), + want: false, + }, + { + name: "json.RawMessage empty string", + system: json.RawMessage(`""`), + want: false, + }, } for _, tt := range tests { @@ -202,6 +223,29 @@ func TestInjectClaudeCodePrompt(t *testing.T) { wantSystemLen: 1, 
wantFirstText: claudeCodeSystemPrompt, }, + // json.RawMessage cases (conversion path: ForwardAsResponses / ForwardAsChatCompletions) + { + name: "json.RawMessage string system", + body: `{"model":"claude-3","system":"Custom prompt"}`, + system: json.RawMessage(`"Custom prompt"`), + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: claudePrefix + "\n\nCustom prompt", + }, + { + name: "json.RawMessage nil system", + body: `{"model":"claude-3"}`, + system: json.RawMessage(nil), + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "json.RawMessage Claude Code prompt (should not duplicate)", + body: `{"model":"claude-3","system":"` + claudeCodeSystemPrompt + `"}`, + system: json.RawMessage(`"` + claudeCodeSystemPrompt + `"`), + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, } for _, tt := range tests { diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 4c1f0317..48488dc8 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -40,6 +40,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, ) } @@ -162,6 +163,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) } +func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + mappedModel := "claude-sonnet-4-20250514" + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_models_split", + Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6}, + Model: "claude-sonnet-4", + 
UpstreamModel: mappedModel, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel) +} + func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 3816aea9..e2badfed 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "math" + "regexp" + "sort" "strings" "unsafe" @@ -28,6 +30,15 @@ var ( patternEmptyContentSpaced = []byte(`"content": []`) patternEmptyContentSp1 = []byte(`"content" : []`) patternEmptyContentSp2 = []byte(`"content" :[]`) + + // Fast-path patterns for empty text blocks: {"type":"text","text":""} + patternEmptyText = []byte(`"text":""`) + patternEmptyTextSpaced = []byte(`"text": ""`) + patternEmptyTextSp1 = []byte(`"text" : ""`) + patternEmptyTextSp2 = []byte(`"text" :""`) + + sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`) + sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -69,6 +80,49 @@ type ParsedRequest struct { OnUpstreamAccepted func() } +// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing. 
+// It preserves the set of product names from Product/Version tokens while +// discarding version-only changes and incidental comments. +func NormalizeSessionUserAgent(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + + matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1) + if len(matches) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + + products := make([]string, 0, len(matches)) + seen := make(map[string]struct{}, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + product := strings.ToLower(strings.TrimSpace(match[1])) + if product == "" { + continue + } + if _, exists := seen[product]; exists { + continue + } + seen[product] = struct{}{} + products = append(products, product) + } + if len(products) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + sort.Strings(products) + return strings.Join(products, "+") +} + +func normalizeSessionUserAgentFallback(raw string) string { + normalized := strings.ToLower(strings.Join(strings.Fields(raw), " ")) + normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "") + return strings.Join(strings.Fields(normalized), " ") +} + // ParseGatewayRequest 解析网关请求体并返回结构化结果。 // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 @@ -199,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte { return []byte(r.Raw) } +// stripEmptyTextBlocksFromSlice removes empty text blocks from a content slice (including nested tool_result content). +// Returns (cleaned slice, true) if any blocks were removed, or (original, false) if unchanged. 
+func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) { + var result []any + changed := false + for i, block := range blocks { + blockMap, ok := block.(map[string]any) + if !ok { + if result != nil { + result = append(result, block) + } + continue + } + blockType, _ := blockMap["type"].(string) + + // Strip empty text blocks + if blockType == "text" { + if txt, _ := blockMap["text"].(string); txt == "" { + if result == nil { + result = make([]any, 0, len(blocks)) + result = append(result, blocks[:i]...) + } + changed = true + continue + } + } + + // Recurse into tool_result nested content + if blockType == "tool_result" { + if nestedContent, ok := blockMap["content"].([]any); ok { + if cleaned, nestedChanged := stripEmptyTextBlocksFromSlice(nestedContent); nestedChanged { + if result == nil { + result = make([]any, 0, len(blocks)) + result = append(result, blocks[:i]...) + } + changed = true + blockCopy := make(map[string]any, len(blockMap)) + for k, v := range blockMap { + blockCopy[k] = v + } + blockCopy["content"] = cleaned + result = append(result, blockCopy) + continue + } + } + } + + if result != nil { + result = append(result, block) + } + } + if !changed { + return blocks, false + } + return result, true +} + +// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content). +// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors. +// Returns the original body unchanged if no empty text blocks are found. 
+func StripEmptyTextBlocks(body []byte) []byte { + // Fast path: check if body contains empty text patterns + hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) || + bytes.Contains(body, patternEmptyTextSpaced) || + bytes.Contains(body, patternEmptyTextSp1) || + bytes.Contains(body, patternEmptyTextSp2) + if !hasEmptyTextBlock { + return body + } + + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { + return body + } + + modified := false + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := msgMap["content"].([]any) + if !ok { + continue + } + if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed { + modified = true + msgMap["content"] = cleaned + } + } + + if !modified { + return body + } + + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err := sjson.SetRawBytes(body, "messages", msgsBytes) + if err != nil { + return body + } + return out +} + // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures @@ -233,15 +399,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { bytes.Contains(body, patternThinkingField) || bytes.Contains(body, patternThinkingFieldSpaced) - // Also check for empty content arrays that need fixing. + // Also check for empty content arrays and empty text blocks that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. 
hasEmptyContent := bytes.Contains(body, patternEmptyContent) || bytes.Contains(body, patternEmptyContentSpaced) || bytes.Contains(body, patternEmptyContentSp1) || bytes.Contains(body, patternEmptyContentSp2) + // Check for empty text blocks: {"type":"text","text":""} + // These cause upstream 400: "text content blocks must be non-empty" + hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) || + bytes.Contains(body, patternEmptyTextSpaced) || + bytes.Contains(body, patternEmptyTextSp1) || + bytes.Contains(body, patternEmptyTextSp2) + // Fast path: nothing to process - if !hasThinkingContent && !hasEmptyContent { + if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock { return body } @@ -260,7 +433,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { bytes.Contains(body, patternTypeRedactedThinking) || bytes.Contains(body, patternTypeRedactedSpaced) || bytes.Contains(body, patternThinkingFieldSpaced) - if !hasEmptyContent && !containsThinkingBlocks { + if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { out = removeThinkingDependentContextStrategies(out) @@ -320,6 +493,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { blockType, _ := blockMap["type"].(string) + // Strip empty text blocks: {"type":"text","text":""} + // Upstream rejects these with 400: "text content blocks must be non-empty" + if blockType == "text" { + if txt, _ := blockMap["text"].(string); txt == "" { + modifiedThisMsg = true + ensureNewContent(bi) + continue + } + } + // Convert thinking blocks to text (preserve content) and drop redacted_thinking. switch blockType { case "thinking": @@ -355,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } + // Recursively strip empty text blocks from tool_result nested content. 
+ if blockType == "tool_result" { + if nestedContent, ok := blockMap["content"].([]any); ok { + if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed { + modifiedThisMsg = true + ensureNewContent(bi) + blockCopy := make(map[string]any, len(blockMap)) + for k, v := range blockMap { + blockCopy[k] = v + } + blockCopy["content"] = cleaned + newContent = append(newContent, blockCopy) + continue + } + } + } + if newContent != nil { newContent = append(newContent, block) } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index f60ed9fb..d262456d 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -404,6 +404,167 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) require.NotEmpty(t, content0["text"]) } +func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) { + // Empty text blocks cause upstream 400: "text content blocks must be non-empty" + input := []byte(`{ + "messages":[ + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}, + {"role":"assistant","content":[{"type":"text","text":""}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs, ok := req["messages"].([]any) + require.True(t, ok) + + // First message: empty text block stripped, "hello" preserved + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 1) + require.Equal(t, "hello", content0[0].(map[string]any)["text"]) + + // Second message: only had empty text block → gets placeholder + msg1 := msgs[1].(map[string]any) + content1 := msg1["content"].([]any) + require.Len(t, content1, 1) + block1 := content1[0].(map[string]any) + require.Equal(t, "text", block1["type"]) + require.NotEmpty(t, block1["text"]) +} + +func 
TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) { + // Empty text blocks nested inside tool_result content should also be stripped + input := []byte(`{ + "messages":[ + {"role":"user","content":[ + {"type":"tool_result","tool_use_id":"t1","content":[ + {"type":"text","text":"valid result"}, + {"type":"text","text":""} + ]} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 1) + toolResult := content0[0].(map[string]any) + require.Equal(t, "tool_result", toolResult["type"]) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 1) + require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"]) +} + +func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) { + // If all nested content blocks in tool_result are empty text, content becomes empty slice + input := []byte(`{ + "messages":[ + {"role":"user","content":[ + {"type":"tool_result","tool_use_id":"t1","content":[ + {"type":"text","text":""} + ]}, + {"type":"text","text":"hello"} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 2) + toolResult := content0[0].(map[string]any) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 0) +} + +func TestStripEmptyTextBlocks(t *testing.T) { + t.Run("strips top-level empty text", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, 
json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + require.Len(t, content, 1) + require.Equal(t, "hello", content[0].(map[string]any)["text"]) + }) + + t.Run("strips nested empty text in tool_result", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + toolResult := content[0].(map[string]any) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 1) + require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"]) + }) + + t.Run("no-op when no empty text", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + out := StripEmptyTextBlocks(input) + require.Equal(t, input, out) + }) + + t.Run("preserves non-map blocks in content", func(t *testing.T) { + // tool_result content can be a string; non-map blocks should pass through unchanged + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + require.Len(t, content, 1) + toolResult := content[0].(map[string]any) + require.Equal(t, "tool_result", toolResult["type"]) + require.Equal(t, "string content", toolResult["content"]) + }) + + t.Run("handles deeply nested tool_result", func(t *testing.T) { + // Recursive: tool_result containing another tool_result with empty text + input := 
[]byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + outer := content[0].(map[string]any) + innerContent := outer["content"].([]any) + inner := innerContent[0].(map[string]any) + deepContent := inner["content"].([]any) + require.Len(t, deepContent, 1) + require.Equal(t, "deep", deepContent[0].(map[string]any)["text"]) + }) +} + +func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) { + // Non-empty text blocks should pass through unchanged + input := []byte(`{ + "messages":[ + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + // Fast path: no thinking content, no empty content, no empty text blocks → unchanged + require.Equal(t, input, out) +} + func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { input := []byte(`{ "thinking":{"type":"enabled","budget_tokens":1024}, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0b50162a..94e04d28 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,7 +12,9 @@ import ( "log/slog" mathrand "math/rand" "net/http" + "net/url" "os" + "path/filepath" "regexp" "sort" "strconv" @@ -51,6 +53,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second postUsageBillingTimeout = 15 * time.Second + debugGatewayBodyEnv = "SUB2API_DEBUG_GATEWAY_BODY" ) const ( @@ -326,7 +329,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // Some upstream APIs return 
non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 @@ -340,12 +342,6 @@ var ( } ) -// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 -// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除 -var systemBlockFilterPrefixes = []string{ - "x-anthropic-billing-header", -} - // ErrNoAvailableAccounts 表示没有可用的账号 var ErrNoAvailableAccounts = errors.New("no available accounts") @@ -372,6 +368,9 @@ var allowedHeaders = map[string]bool{ "sec-fetch-mode": true, "user-agent": true, "content-type": true, + "accept-encoding": true, + "x-claude-code-session-id": true, + "x-client-request-id": true, } // GatewayCache 定义网关服务的缓存操作接口。 @@ -488,9 +487,12 @@ type ClaudeUsage struct { // ForwardResult 转发结果 type ForwardResult struct { - RequestID string - Usage ClaudeUsage - Model string + RequestID string + Usage ClaudeUsage + Model string + // UpstreamModel is the actual upstream model after mapping. + // Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings. 
+ UpstreamModel string Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -566,6 +568,8 @@ type GatewayService struct { responseHeaderFilter *responseheaders.CompiledHeaderFilter debugModelRouting atomic.Bool debugClaudeMimic atomic.Bool + debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set + tlsFPProfileService *TLSFingerprintProfileService } // NewGatewayService creates a new GatewayService @@ -592,6 +596,7 @@ func NewGatewayService( rpmCache RPMCache, digestStore *DigestSessionStore, settingService *SettingService, + tlsFPProfileService *TLSFingerprintProfileService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -623,6 +628,7 @@ func NewGatewayService( modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), + tlsFPProfileService: tlsFPProfileService, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -633,6 +639,9 @@ func NewGatewayService( ) svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + if path := strings.TrimSpace(os.Getenv(debugGatewayBodyEnv)); path != "" { + svc.initDebugGatewayBodyFile(path) + } return svc } @@ -644,8 +653,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // 1. 
// anthropicCacheControlPayload is the wire form of Anthropic's cache_control
// object, e.g. {"type":"ephemeral"}.
type anthropicCacheControlPayload struct {
	Type string `json:"type"`
}

// anthropicSystemTextBlockPayload is the wire form of one system text block.
// CacheControl is dropped from the output when nil (omitempty).
type anthropicSystemTextBlockPayload struct {
	Type         string                        `json:"type"`
	Text         string                        `json:"text"`
	CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"`
}

// anthropicMetadataPayload is the wire form of the request "metadata" object.
type anthropicMetadataPayload struct {
	UserID string `json:"user_id"`
}

// marshalAnthropicSystemTextBlock serializes a system text block, optionally
// tagged with an ephemeral cache_control marker.
func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) {
	var control *anthropicCacheControlPayload
	if includeCacheControl {
		control = &anthropicCacheControlPayload{Type: "ephemeral"}
	}
	return json.Marshal(anthropicSystemTextBlockPayload{
		Type:         "text",
		Text:         text,
		CacheControl: control,
	})
}

// marshalAnthropicMetadata serializes a metadata object carrying only user_id.
func marshalAnthropicMetadata(userID string) ([]byte, error) {
	return json.Marshal(anthropicMetadataPayload{UserID: userID})
}

// buildJSONArrayRaw joins pre-marshaled JSON values into a JSON array without
// re-encoding them, so each element's original bytes (field order, spacing)
// are preserved verbatim.
func buildJSONArrayRaw(items [][]byte) []byte {
	if len(items) == 0 {
		return []byte("[]")
	}

	// Exact size: two brackets, the payload bytes, and len(items)-1 commas.
	size := 2 + len(items) - 1
	for _, raw := range items {
		size += len(raw)
	}

	out := make([]byte, 1, size)
	out[0] = '['
	for i, raw := range items {
		if i != 0 {
			out = append(out, ',')
		}
		out = append(out, raw...)
	}
	return append(out, ']')
}
+ } + buf = append(buf, ']') + return buf +} + +func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) { + next, err := sjson.SetBytes(body, path, value) + if err != nil { + return body, false + } + return next, true +} + +func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) { + next, err := sjson.SetRawBytes(body, path, raw) + if err != nil { + return body, false + } + return next, true +} + +func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) { + next, err := sjson.DeleteBytes(body, path) + if err != nil { + return body, false + } + return next, true +} + +func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body, false + } + + out := body + modified := false + + switch { + case sys.Type == gjson.String: + sanitized := sanitizeSystemText(sys.String()) + if sanitized != sys.String() { + if next, ok := setJSONValueBytes(out, "system", sanitized); ok { + out = next + modified = true + } + } + case sys.IsArray(): + index := 0 + sys.ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "text" { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + sanitized := sanitizeSystemText(text) + if sanitized != text { + if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok { + out = next + modified = true + } + } + } + } + + if opts.stripSystemCacheControl && item.Get("cache_control").Exists() { + if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok { + out = next + modified = true + } + } + + index++ + return true + }) + } + + return out, modified +} + +func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { + if strings.TrimSpace(userID) == "" { + return body, false + } + + metadata := gjson.GetBytes(body, 
"metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) + } + + trimmedRaw := strings.TrimSpace(metadata.Raw) + if strings.HasPrefix(trimmedRaw, "{") { + existing := metadata.Get("user_id") + if existing.Exists() && existing.Type == gjson.String && existing.String() != "" { + return body, false + } + return setJSONValueBytes(body, "metadata.user_id", userID) + } + + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) } func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { @@ -909,96 +1050,59 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu return body, modelID } - // 解析为 map[string]any 用于修改字段 - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return body, modelID - } - + out := body modified := false - if system, ok := req["system"]; ok { - switch v := system.(type) { - case string: - sanitized := sanitizeSystemText(v) - if sanitized != v { - req["system"] = sanitized - modified = true - } - case []any: - for _, item := range v { - block, ok := item.(map[string]any) - if !ok { - continue - } - if blockType, _ := block["type"].(string); blockType != "text" { - continue - } - text, ok := block["text"].(string) - if !ok || text == "" { - continue - } - sanitized := sanitizeSystemText(text) - if sanitized != text { - block["text"] = sanitized - modified = true - } - } - } + if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { + out = next + modified = true } - if rawModel, ok := req["model"].(string); ok { - normalized := claude.NormalizeModelID(rawModel) - if normalized != rawModel { - req["model"] = normalized + rawModel := gjson.GetBytes(out, "model") + if rawModel.Exists() && 
rawModel.Type == gjson.String { + normalized := claude.NormalizeModelID(rawModel.String()) + if normalized != rawModel.String() { + if next, ok := setJSONValueBytes(out, "model", normalized); ok { + out = next + modified = true + } modelID = normalized - modified = true } } // 确保 tools 字段存在(即使为空数组) - if _, exists := req["tools"]; !exists { - req["tools"] = []any{} - modified = true - } - - if opts.stripSystemCacheControl { - if system, ok := req["system"]; ok { - _ = stripCacheControlFromSystemBlocks(system) + if !gjson.GetBytes(out, "tools").Exists() { + if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok { + out = next modified = true } } if opts.injectMetadata && opts.metadataUserID != "" { - metadata, ok := req["metadata"].(map[string]any) - if !ok { - metadata = map[string]any{} - req["metadata"] = metadata - } - if existing, ok := metadata["user_id"].(string); !ok || existing == "" { - metadata["user_id"] = opts.metadataUserID + if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed { + out = next modified = true } } - if _, hasTemp := req["temperature"]; hasTemp { - delete(req, "temperature") - modified = true + if gjson.GetBytes(out, "temperature").Exists() { + if next, ok := deleteJSONPathBytes(out, "temperature"); ok { + out = next + modified = true + } } - if _, hasChoice := req["tool_choice"]; hasChoice { - delete(req, "tool_choice") - modified = true + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } } if !modified { return body, modelID } - newBody, err := json.Marshal(req) - if err != nil { - return body, modelID - } - return newBody, modelID + return out, modelID } func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { @@ -1026,13 +1130,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account sessionID = 
generateSessionUUID(seed) } - // Prefer the newer format that includes account_uuid (if present), - // otherwise fall back to the legacy Claude Code format. - accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) - if accountUUID != "" { - return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) + // 根据指纹 UA 版本选择输出格式 + var uaVersion string + if fp != nil { + uaVersion = ExtractCLIVersion(fp.UserAgent) } - return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) } // GenerateSessionUUID creates a deterministic UUID4 from a seed string. @@ -1314,19 +1418,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if s.isAccountSchedulableForSelection(stickyAccount) && + var stickyCacheMissReason string + + gatePass := s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) + + if rpmPass { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, 
sessionHash) { result.ReleaseFunc() // 释放槽位 + stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { @@ -1340,27 +1449,49 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - // 会话限制已满,继续到负载感知选择 + if stickyCacheMissReason == "" { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + stickyCacheMissReason = "session_limit" + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + stickyCacheMissReason = "wait_queue_full" } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } else if !gatePass { + stickyCacheMissReason = "gate_check" + } else { + stickyCacheMissReason = "rpm_red" + } + + // 记录粘性缓存未命中的结构化日志 + if stickyCacheMissReason != "" { + baseRPM := stickyAccount.GetBaseRPM() + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { + currentRPM = count + } + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", + stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), 
currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", + stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -2640,6 +2771,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2711,6 +2848,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2813,6 +2956,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2876,6 +3025,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := 
s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2943,6 +3098,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3047,6 +3208,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3648,9 +3815,28 @@ func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequ return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) } +// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), +// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。 +// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 +func normalizeSystemParam(system any) any { + raw, ok := system.(json.RawMessage) + if !ok { + return system + } + if len(raw) == 0 { + return nil + } + var parsed any + if err := json.Unmarshal(raw, 
&parsed); err != nil { + return nil + } + return parsed +} + // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) func systemIncludesClaudeCodePrompt(system any) bool { + system = normalizeSystemParam(system) switch v := system.(type) { case string: return hasClaudeCodePrefix(v) @@ -3676,82 +3862,29 @@ func hasClaudeCodePrefix(text string) bool { return false } -// matchesFilterPrefix 检查文本是否匹配任一过滤前缀 -func matchesFilterPrefix(text string) bool { - for _, prefix := range systemBlockFilterPrefixes { - if strings.HasPrefix(text, prefix) { - return true - } - } - return false -} - -// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素 -// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system) -func filterSystemBlocksByPrefix(body []byte) []byte { - sys := gjson.GetBytes(body, "system") - if !sys.Exists() { - return body - } - - switch { - case sys.Type == gjson.String: - if matchesFilterPrefix(sys.Str) { - result, err := sjson.DeleteBytes(body, "system") - if err != nil { - return body - } - return result - } - case sys.IsArray(): - var parsed []any - if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil { - return body - } - filtered := make([]any, 0, len(parsed)) - changed := false - for _, item := range parsed { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) { - changed = true - continue - } - } - filtered = append(filtered, item) - } - if changed { - result, err := sjson.SetBytes(body, "system", filtered) - if err != nil { - return body - } - return result - } - } - return body -} - // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { - claudeCodeBlock := map[string]any{ - "type": "text", - "text": claudeCodeSystemPrompt, - "cache_control": map[string]string{"type": "ephemeral"}, + system = 
normalizeSystemParam(system) + claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err) + return body } // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code // banner, it also prefixes the next system instruction with the same banner plus // a blank line. This helps when upstream concatenates system instructions. claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) - var newSystem []any + var items [][]byte switch v := system.(type) { case nil: - newSystem = []any{claudeCodeBlock} + items = [][]byte{claudeCodeBlock} case string: // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { - newSystem = []any{claudeCodeBlock} + items = [][]byte{claudeCodeBlock} } else { // Mirror opencode behavior: keep the banner as a separate system entry, // but also prefix the next system text with the banner. 
@@ -3759,18 +3892,54 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { if !strings.HasPrefix(v, claudeCodePrefix) { merged = claudeCodePrefix + "\n\n" + v } - newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} + nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false) + if buildErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr) + return body + } + items = [][]byte{claudeCodeBlock, nextBlock} } case []any: - newSystem = make([]any, 0, len(v)+1) - newSystem = append(newSystem, claudeCodeBlock) + items = make([][]byte, 0, len(v)+1) + items = append(items, claudeCodeBlock) prefixedNext := false - for _, item := range v { - if m, ok := item.(map[string]any); ok { + systemResult := gjson.GetBytes(body, "system") + if systemResult.IsArray() { + systemResult.ForEach(func(_, item gjson.Result) bool { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String && + strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) { + return true + } + + raw := []byte(item.Raw) + // Prefix the first subsequent text system block once. 
// cacheControlPath records the gjson/sjson path of one cache_control entry,
// plus the log line to emit if it is removed.
type cacheControlPath struct {
	path string
	log  string
}

// collectCacheControlPaths scans system[] and messages[].content[] for
// cache_control entries and classifies them:
//   - invalidThinking: cache_control on "thinking" blocks (always illegal);
//   - messagePaths / systemPaths: legal entries, by location, in document order.
// Paths are sjson-style (e.g. "messages.2.content.0.cache_control").
func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, systemPaths []string) {
	system := gjson.GetBytes(body, "system")
	if system.IsArray() {
		sysIndex := 0
		system.ForEach(func(_, item gjson.Result) bool {
			if item.Get("cache_control").Exists() {
				path := fmt.Sprintf("system.%d.cache_control", sysIndex)
				if item.Get("type").String() == "thinking" {
					invalidThinking = append(invalidThinking, cacheControlPath{
						path: path,
						log:  "[Warning] Removed illegal cache_control from thinking block in system",
					})
				} else {
					systemPaths = append(systemPaths, path)
				}
			}
			sysIndex++
			return true
		})
	}

	messages := gjson.GetBytes(body, "messages")
	if messages.IsArray() {
		msgIndex := 0
		messages.ForEach(func(_, msg gjson.Result) bool {
			content := msg.Get("content")
			if content.IsArray() {
				contentIndex := 0
				content.ForEach(func(_, item gjson.Result) bool {
					if item.Get("cache_control").Exists() {
						path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
						if item.Get("type").String() == "thinking" {
							invalidThinking = append(invalidThinking, cacheControlPath{
								path: path,
								log:  fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
							})
						} else {
							messagePaths = append(messagePaths, path)
						}
					}
					contentIndex++
					return true
				})
			}
			msgIndex++
			return true
		})
	}

	return invalidThinking, messagePaths, systemPaths
}

// enforceCacheControlLimit enforces the upstream cap of maxCacheControlBlocks
// cache_control entries (currently 4). Thinking-block cache_control is always
// stripped first (the field is illegal there). When still over the cap,
// entries are removed from messages first (front-to-back), then from system
// (tail-first, protecting the injected prompt at the head of system).
// Deleting an object key never shifts array indices, so the pre-collected
// paths remain valid across deletions; Exists() is still re-checked per path.
func enforceCacheControlLimit(body []byte) []byte {
	if len(body) == 0 {
		return body
	}

	invalidThinking, messagePaths, systemPaths := collectCacheControlPaths(body)
	out := body
	modified := false

	// Strip illegal cache_control from thinking blocks unconditionally.
	for _, item := range invalidThinking {
		if !gjson.GetBytes(out, item.path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, item.path)
		if !ok {
			continue
		}
		out = next
		modified = true
		logger.LegacyPrintf("service.gateway", "%s", item.log)
	}

	count := len(messagePaths) + len(systemPaths)
	if count <= maxCacheControlBlocks {
		if modified {
			return out
		}
		return body
	}

	// Over the cap: remove from messages first, then from system.
	remaining := count - maxCacheControlBlocks
	for _, path := range messagePaths {
		if remaining <= 0 {
			break
		}
		if !gjson.GetBytes(out, path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, path)
		if !ok {
			continue
		}
		out = next
		modified = true
		remaining--
	}

	// System entries go tail-first so the head (injected prompt) is kept.
	for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
		path := systemPaths[i]
		if !gjson.GetBytes(out, path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, path)
		if !ok {
			continue
		}
		out = next
		modified = true
		remaining--
	}

	if modified {
		return out
	}
	return body
}
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -4030,9 +4168,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { - if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { - normalizeOpts.injectMetadata = true - normalizeOpts.metadataUserID = metadataUserID + // metadata 透传开启时跳过 metadata 注入 + _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) + if !mimicMPT { + if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = metadataUserID + } } } } @@ -4040,12 +4182,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } - // OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据) - // 放在 inject/normalize 之后,确保不会被覆盖 - if account.IsOAuth() { - body = filterSystemBlocksByPrefix(body) - } - // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) @@ -4080,15 +4216,23 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } + // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) + tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account) + // 调试日志:记录即将转发的账号信息 logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", - account.ID, account.Name, account.Platform, account.Type, 
account.IsTLSFingerprintEnabled(), proxyURL) + account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL) + // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. + body = StripEmptyTextBlocks(body) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 setOpsUpstreamRequestBody(c, body) @@ -4105,7 +4249,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 发送请求 - resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, tlsProfile) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -4118,6 +4262,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -4137,13 +4282,14 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if readErr == nil { _ = resp.Body.Close() - if s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { + if s.shouldRectifySignatureError(ctx, account, respBody) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "signature_error", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4169,7 +4315,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with 
filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) @@ -4181,23 +4327,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) releaseRetryCtx() if buildErr == nil { - retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, tlsProfile) if retryErr == nil { if retryResp.StatusCode < 400 { - logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID) resp = retryResp break } retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) _ = retryResp.Body.Close() - if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) { + if retryReadErr == nil && retryResp.StatusCode == 400 && s.isSignatureErrorPattern(ctx, account, retryRespBody) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: retryResp.StatusCode, UpstreamRequestID: retryResp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(retryReq.URL.String()), Kind: "signature_retry_thinking", Message: extractUpstreamErrorMessage(retryRespBody), Detail: func() string { @@ -4215,7 +4362,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryReq2, buildErr2 := 
s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) releaseRetryCtx2() if buildErr2 == nil { - retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, tlsProfile) if retryErr2 == nil { resp = retryResp2 break @@ -4228,6 +4375,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(retryReq2.URL.String()), Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) @@ -4267,6 +4415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "budget_constraint_error", Message: errMsg, Detail: func() string { @@ -4284,7 +4433,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) releaseBudgetRetryCtx() if buildErr == nil { - budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, tlsProfile) if retryErr == nil { resp = budgetRetryResp break @@ -4328,6 +4477,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: 
resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4513,6 +4663,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -4520,14 +4671,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +type anthropicPassthroughForwardInput struct { + Body []byte + RequestModel string + OriginalModel string + RequestStream bool + StartTime time.Time +} + func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ctx context.Context, c *gin.Context, account *Account, body []byte, reqModel string, + originalModel string, reqStream bool, startTime time.Time, +) (*ForwardResult, error) { + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: body, + RequestModel: reqModel, + OriginalModel: originalModel, + RequestStream: reqStream, + StartTime: startTime, + }) +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( + ctx context.Context, + c *gin.Context, + account *Account, + input anthropicPassthroughForwardInput, ) (*ForwardResult, error) { token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { @@ -4543,25 +4718,28 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( } logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", - account.ID, account.Name, reqModel, reqStream) + account.ID, account.Name, input.RequestModel, input.RequestStream) if c != nil { c.Set("anthropic_passthrough", true) } + // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. 
+ input.Body = StripEmptyTextBlocks(input.Body) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 - setOpsUpstreamRequestBody(c, body) + setOpsUpstreamRequestBody(c, input.Body) var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) releaseUpstreamCtx() if err != nil { return nil, err } - resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -4573,6 +4751,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: safeErr, @@ -4612,6 +4791,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "retry", Message: extractUpstreamErrorMessage(respBody), @@ -4713,8 +4893,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool - if reqStream { - streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if 
input.RequestStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) if err != nil { return nil, err } @@ -4734,9 +4914,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( return &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, - Model: reqModel, - Stream: reqStream, - Duration: time.Since(startTime), + Model: input.OriginalModel, + UpstreamModel: input.RequestModel, + Stream: input.RequestStream, + Duration: time.Since(input.StartTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, }, nil @@ -4770,8 +4951,9 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( if !allowedHeaders[lowerKey] { continue } + wireKey := resolveWireCasing(key) for _, v := range values { - req.Header.Add(key, v) + addHeaderRaw(req.Header, wireKey, v) } } } @@ -4781,13 +4963,13 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( req.Header.Del("x-api-key") req.Header.Del("x-goog-api-key") req.Header.Del("cookie") - req.Header.Set("x-api-key", token) + setHeaderRaw(req.Header, "x-api-key", token) - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") + if getHeaderRaw(req.Header, "content-type") == "" { + setHeaderRaw(req.Header, "content-type", "application/json") } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") + if getHeaderRaw(req.Header, "anthropic-version") == "" { + setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } return req, nil @@ -5241,6 +5423,7 @@ func (s *GatewayService) forwardBedrock( RequestID: resp.Header.Get("x-amzn-requestid"), Usage: *usage, Model: reqModel, + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -5275,7 +5458,7 @@ func (s *GatewayService) executeBedrockUpstream( return nil, err } - resp, err = 
s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false) + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, nil) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -5287,6 +5470,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -5323,6 +5507,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -5511,6 +5696,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -5518,8 +5713,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex clientHeaders = c.Request.Header } - // OAuth账号:应用统一指纹 + // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint + enableFP, enableMPT := true, false + if s.settingService != nil { + enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + } if account.IsOAuth() && s.identityService != nil { // 1. 
获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) @@ -5527,14 +5726,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers } else { - fingerprint = fp + if enableFP { + fingerprint = fp + } // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { - body = newBody + // 当 metadata 透传开启时跳过重写 + if !enableMPT { + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { + body = newBody + } } } } @@ -5545,19 +5749,20 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex return nil, err } - // 设置认证头 + // 设置认证头(保持原始大小写) if tokenType == "oauth" { - req.Header.Set("authorization", "Bearer "+token) + setHeaderRaw(req.Header, "authorization", "Bearer "+token) } else { - req.Header.Set("x-api-key", token) + setHeaderRaw(req.Header, "x-api-key", token) } - // 白名单透传headers + // 白名单透传headers(恢复真实 wire casing) for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { + wireKey := resolveWireCasing(key) for _, v := range values { - req.Header.Add(key, v) + addHeaderRaw(req.Header, wireKey, v) } } } @@ -5567,15 +5772,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex s.identityService.ApplyFingerprint(req, fingerprint) } - // 确保必要的headers存在 - if req.Header.Get("content-type") 
== "" { - req.Header.Set("content-type", "application/json") + // 确保必要的headers存在(保持原始大小写) + if getHeaderRaw(req.Header, "content-type") == "" { + setHeaderRaw(req.Header, "content-type", "application/json") } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") + if getHeaderRaw(req.Header, "anthropic-version") == "" { + setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } if tokenType == "oauth" { - applyClaudeOAuthHeaderDefaults(req, reqStream) + applyClaudeOAuthHeaderDefaults(req) } // Build effective drop set: merge static defaults with dynamic beta policy filter rules @@ -5591,31 +5796,50 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 applyClaudeCodeMimicHeaders(req, reqStream) - incomingBeta := req.Header.Get("anthropic-beta") + incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") // Match real Claude CLI traffic (per mitmproxy reports): // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. 
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta - clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) + clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet)) } } else { // API-key accounts: apply beta policy filter to strip controlled tokens - if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet)) + if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet)) } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) if requestNeedsBetaFeatures(body) { if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + setHeaderRaw(req.Header, "anthropic-beta", beta) } } } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + + // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === + 
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ + "url": req.URL.String(), + "token_type": tokenType, + "mimic_claude_code": strconv.FormatBool(mimicClaudeCode), + "fingerprint_applied": strconv.FormatBool(fingerprint != nil), + "enable_fp": strconv.FormatBool(enableFP), + "enable_mpt": strconv.FormatBool(enableMPT), + }) + // Always capture a compact fingerprint line for later error diagnostics. // We only print it when needed (or when the explicit debug flag is enabled). if c != nil && tokenType == "oauth" { @@ -5695,24 +5919,21 @@ func defaultAPIKeyBetaHeader(body []byte) string { return claude.APIKeyBetaHeader } -func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { +func applyClaudeOAuthHeaderDefaults(req *http.Request) { if req == nil { return } - if req.Header.Get("accept") == "" { - req.Header.Set("accept", "application/json") + if getHeaderRaw(req.Header, "Accept") == "" { + setHeaderRaw(req.Header, "Accept", "application/json") } for key, value := range claude.DefaultHeaders { if value == "" { continue } - if req.Header.Get(key) == "" { - req.Header.Set(key, value) + if getHeaderRaw(req.Header, key) == "" { + setHeaderRaw(req.Header, resolveWireCasing(key), value) } } - if isStream && req.Header.Get("x-stainless-helper-method") == "" { - req.Header.Set("x-stainless-helper-method", "stream") - } } func mergeAnthropicBeta(required []string, incoming string) string { @@ -6007,18 +6228,19 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { return } // Start with the standard defaults (fill missing). - applyClaudeOAuthHeaderDefaults(req, isStream) + applyClaudeOAuthHeaderDefaults(req) // Then force key headers to match Claude Code fingerprint regardless of what the client sent. 
+ // 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App") for key, value := range claude.DefaultHeaders { if value == "" { continue } - req.Header.Set(key, value) + setHeaderRaw(req.Header, resolveWireCasing(key), value) } // Real Claude CLI uses Accept: application/json (even for streaming). - req.Header.Set("accept", "application/json") + setHeaderRaw(req.Header, "Accept", "application/json") if isStream { - req.Header.Set("x-stainless-helper-method", "stream") + setHeaderRaw(req.Header, "x-stainless-helper-method", "stream") } } @@ -6036,6 +6258,59 @@ func truncateForLog(b []byte, maxBytes int) string { return s } +// shouldRectifySignatureError 统一判断是否应触发签名整流(strip thinking blocks 并重试)。 +// 根据账号类型检查对应的开关和匹配模式。 +func (s *GatewayService) shouldRectifySignatureError(ctx context.Context, account *Account, respBody []byte) bool { + if account.Type == AccountTypeAPIKey { + // API Key 账号:独立开关,一次读取配置 + settings, err := s.settingService.GetRectifierSettings(ctx) + if err != nil || !settings.Enabled || !settings.APIKeySignatureEnabled { + return false + } + // 先检查内置模式(同 OAuth),再检查自定义关键词 + if s.isThinkingBlockSignatureError(respBody) { + return true + } + return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) + } + // OAuth/SetupToken/Upstream/Bedrock 等:保持原有行为(内置模式 + 原开关) + return s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) +} + +// isSignatureErrorPattern 仅做模式匹配,不检查开关。 +// 用于已进入重试流程后的二阶段检测(此时开关已在首次调用时验证过)。 +func (s *GatewayService) isSignatureErrorPattern(ctx context.Context, account *Account, respBody []byte) bool { + if s.isThinkingBlockSignatureError(respBody) { + return true + } + if account.Type == AccountTypeAPIKey { + settings, err := s.settingService.GetRectifierSettings(ctx) + if err != nil { + return false + } + return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) + } + return false +} + +// matchSignaturePatterns 检查响应体是否匹配自定义关键词列表(不区分大小写)。 +func 
matchSignaturePatterns(respBody []byte, patterns []string) bool { + if len(patterns) == 0 { + return false + } + bodyLower := strings.ToLower(string(respBody)) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if strings.Contains(bodyLower, strings.ToLower(p)) { + return true + } + } + return false +} + // isThinkingBlockSignatureError 检测是否是thinking block相关错误 // 这类错误可以通过过滤thinking blocks并重试来解决 func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { @@ -6044,13 +6319,9 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return false } - // Log for debugging - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg) - // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error") return true } @@ -6068,9 +6339,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return true } - // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block) // 例如: "all messages must have non-empty content" - if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { + // "messages: text content blocks must be non-empty" + if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") || + strings.Contains(msg, "content blocks must be non-empty") { logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") return true } @@ -7456,6 +7729,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) // 根据请求类型选择计费方式 if result.MediaType == "image" || result.MediaType == "video" { @@ -7471,7 +7745,7 @@ func (s *GatewayService) 
RecordUsage(ctx context.Context, input *RecordUsageInpu if result.MediaType == "image" { cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) } else { - cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) } } else if result.MediaType == "prompt" { cost = &CostBreakdown{} @@ -7485,7 +7759,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费 tokens := UsageTokens{ @@ -7497,7 +7771,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7529,6 +7803,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), @@ -7658,6 +7934,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, 
result.UpstreamModel) // 根据请求类型选择计费方式 if result.ImageCount > 0 { @@ -7670,7 +7947,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ @@ -7682,7 +7959,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7710,6 +7987,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), @@ -7813,6 +8092,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body := parsed.Body reqModel := parsed.Model + // Pre-filter: strip empty text blocks to prevent upstream 400. 
+ body = StripEmptyTextBlocks(body) + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -7868,14 +8150,16 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 发送请求 - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") @@ -7897,13 +8181,13 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) - if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { + if resp.StatusCode == 400 && s.shouldRectifySignatureError(ctx, account, respBody) { logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) if buildErr == nil { - retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, 
account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if retryErr == nil { resp = retryResp respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) @@ -7992,7 +8276,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if err != nil { setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -8000,6 +8284,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: sanitizeUpstreamErrorMessage(err.Error()), @@ -8055,6 +8340,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "http_error", Message: upstreamMsg, @@ -8112,8 +8398,9 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( if !allowedHeaders[lowerKey] { continue } + wireKey := resolveWireCasing(key) for _, v := range values { - req.Header.Add(key, v) + addHeaderRaw(req.Header, wireKey, v) } } } @@ -8147,6 +8434,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url 
is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8154,15 +8451,23 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con clientHeaders = c.Request.Header } - // OAuth 账号:应用统一指纹和重写 userID + // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 + ctEnableFP, ctEnableMPT := true, false + if s.settingService != nil { + ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + } + var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err == nil { - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { - body = newBody + ctFingerprint = fp + if !ctEnableMPT { + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { + body = newBody + } } } } @@ -8173,40 +8478,38 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con return nil, err } - // 设置认证头 + // 设置认证头(保持原始大小写) if tokenType == "oauth" { - req.Header.Set("authorization", "Bearer "+token) + setHeaderRaw(req.Header, "authorization", "Bearer "+token) } else { - req.Header.Set("x-api-key", token) + setHeaderRaw(req.Header, "x-api-key", token) } - // 白名单透传 headers + // 白名单透传 headers(恢复真实 wire casing) for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if 
allowedHeaders[lowerKey] { + wireKey := resolveWireCasing(key) for _, v := range values { - req.Header.Add(key, v) + addHeaderRaw(req.Header, wireKey, v) } } } - // OAuth 账号:应用指纹到请求头 - if account.IsOAuth() && s.identityService != nil { - fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) - if fp != nil { - s.identityService.ApplyFingerprint(req, fp) - } + // OAuth 账号:应用指纹到请求头(受设置开关控制) + if ctEnableFP && ctFingerprint != nil { + s.identityService.ApplyFingerprint(req, ctFingerprint) } - // 确保必要的 headers 存在 - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") + // 确保必要的 headers 存在(保持原始大小写) + if getHeaderRaw(req.Header, "content-type") == "" { + setHeaderRaw(req.Header, "content-type", "application/json") } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") + if getHeaderRaw(req.Header, "anthropic-version") == "" { + setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } if tokenType == "oauth" { - applyClaudeOAuthHeaderDefaults(req, false) + applyClaudeOAuthHeaderDefaults(req) } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules @@ -8217,35 +8520,44 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if mimicClaudeCode { applyClaudeCodeMimicHeaders(req, false) - incomingBeta := req.Header.Get("anthropic-beta") + incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) } else { - clientBetaHeader := req.Header.Get("anthropic-beta") + clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") if 
clientBetaHeader == "" { - req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader) + setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader) } else { beta := s.getBetaHeader(modelID, clientBetaHeader) if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) } } } else { // API-key accounts: apply beta policy filter to strip controlled tokens - if existingBeta := req.Header.Get("anthropic-beta"); existingBeta != "" { - req.Header.Set("anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) + if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { // API-key:与 messages 同步的按需 beta 注入(默认关闭) if requestNeedsBetaFeatures(body) { if beta := defaultAPIKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) + setHeaderRaw(req.Header, "anthropic-beta", beta) } } } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8267,6 +8579,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } +// buildCustomRelayURL 构建自定义中继转发 URL +// 在 path 后附加 beta=true 和可选的 proxy 查询参数 +func (s 
*GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { + u := strings.TrimRight(baseURL, "/") + path + "?beta=true" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL := account.Proxy.URL() + if proxyURL != "" { + u += "&proxy=" + url.QueryEscape(proxyURL) + } + } + return u +} + func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) @@ -8410,3 +8735,95 @@ func reconcileCachedTokens(usage map[string]any) bool { usage["cache_read_input_tokens"] = cached return true } + +const debugGatewayBodyDefaultFilename = "gateway_debug.log" + +// initDebugGatewayBodyFile 初始化网关调试日志文件。 +// +// - "1"/"true" 等布尔值 → 当前目录下 gateway_debug.log +// - 已有目录路径 → 该目录下 gateway_debug.log +// - 其他 → 视为完整文件路径 +func (s *GatewayService) initDebugGatewayBodyFile(path string) { + if parseDebugEnvBool(path) { + path = debugGatewayBodyDefaultFilename + } + + // 如果 path 指向一个已存在的目录,自动追加默认文件名 + if info, err := os.Stat(path); err == nil && info.IsDir() { + path = filepath.Join(path, debugGatewayBodyDefaultFilename) + } + + // 确保父目录存在 + if dir := filepath.Dir(path); dir != "." 
{ + if err := os.MkdirAll(dir, 0755); err != nil { + slog.Error("failed to create gateway debug log directory", "dir", dir, "error", err) + return + } + } + + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + slog.Error("failed to open gateway debug log file", "path", path, "error", err) + return + } + s.debugGatewayBodyFile.Store(f) + slog.Info("gateway debug logging enabled", "path", path) +} + +// debugLogGatewaySnapshot 将网关请求的完整快照(headers + body)写入独立的调试日志文件, +// 用于对比客户端原始请求和上游转发请求。 +// +// 启用方式(环境变量): +// +// SUB2API_DEBUG_GATEWAY_BODY=1 # 写入 gateway_debug.log +// SUB2API_DEBUG_GATEWAY_BODY=/tmp/gateway_debug.log # 写入指定路径 +// +// tag: "CLIENT_ORIGINAL" 或 "UPSTREAM_FORWARD" +func (s *GatewayService) debugLogGatewaySnapshot(tag string, headers http.Header, body []byte, extra map[string]string) { + f := s.debugGatewayBodyFile.Load() + if f == nil { + return + } + + var buf strings.Builder + ts := time.Now().Format("2006-01-02 15:04:05.000") + fmt.Fprintf(&buf, "\n========== [%s] %s ==========\n", ts, tag) + + // 1. context + if len(extra) > 0 { + fmt.Fprint(&buf, "--- context ---\n") + extraKeys := make([]string, 0, len(extra)) + for k := range extra { + extraKeys = append(extraKeys, k) + } + sort.Strings(extraKeys) + for _, k := range extraKeys { + fmt.Fprintf(&buf, " %s: %s\n", k, extra[k]) + } + } + + // 2. headers(按真实 Claude CLI wire 顺序排列,便于与抓包对比;auth 脱敏) + fmt.Fprint(&buf, "--- headers ---\n") + for _, k := range sortHeadersByWireOrder(headers) { + for _, v := range headers[k] { + fmt.Fprintf(&buf, " %s: %s\n", k, safeHeaderValueForLog(k, v)) + } + } + + // 3. 
body(完整输出,格式化 JSON 便于 diff) + fmt.Fprint(&buf, "--- body ---\n") + if len(body) == 0 { + fmt.Fprint(&buf, " (empty)\n") + } else { + var pretty bytes.Buffer + if json.Indent(&pretty, body, " ", " ") == nil { + fmt.Fprintf(&buf, " %s\n", pretty.Bytes()) + } else { + // JSON 格式化失败时原样输出 + fmt.Fprintf(&buf, " %s\n", body) + } + } + + // 写入文件(调试用,并发写入可能交错但不影响可读性) + _, _ = f.WriteString(buf.String()) +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index e65c838d..5b1abc11 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: req.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } setOpsUpstreamError(c, 0, safeErr, "") @@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } // Final attempt: surface the upstream error body (passed through below) instead of a generic retry error. @@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: requestID, - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: requestID, + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } @@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
} return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index 7560f480..f659f0e6 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "encoding/json" "fmt" "io" @@ -11,10 +12,35 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +type geminiCompatHTTPUpstreamStub struct { + response *http.Response + err error + calls int + lastReq *http.Request +} + +func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + s.calls++ + s.lastReq = req + if s.err != nil { + return nil, s.err + } + if s.response == nil { + return nil, fmt.Errorf("missing stub response") + } + resp := *s.response + return &resp, nil +} + +func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, accountConcurrency) +} + // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { tests := []struct { @@ -170,6 +196,42 @@ func 
TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") } +func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"gemini-req-1"}}, + Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)), + }, + } + svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}} + account := &Account{ + ID: 1, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-key", + "model_mapping": map[string]any{ + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`) + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "claude-sonnet-4", result.Model) + require.Equal(t, "claude-sonnet-4-20250514", result.UpstreamModel) + require.Equal(t, 1, httpStub.calls) + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index b0b804eb..5e09b95a 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ 
b/backend/internal/service/gemini_multiplatform_test.go @@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 1780d1da..cd291328 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string { // 返回 16 字符的 Base64 编码的 SHA256 前缀 func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { // 组合所有标识符 + normalizedUserAgent 
:= NormalizeSessionUserAgent(userAgent) combined := strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(apiKeyID, 10) + ":" + ip + ":" + - userAgent + ":" + + normalizedUserAgent + ":" + platform + ":" + model diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index a034cddd..27321996 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } +func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + +func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 1dab67c4..7add3460 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if tierID != "" { account.Credentials["tier_id"] = tierID } - _ = 
p.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials) } } diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index 8aa358a5..39679c3d 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", System: "You are a helpful assistant.", HasSystem: true, Messages: []any{ @@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", Messages: []any{ map[string]any{"role": "user", "content": "hello"}, }, @@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { "metadata session_id should take priority over SessionContext") } +func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`, + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format 
metadata session_id should have highest priority") +} + func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { svc := &GatewayService{} @@ -488,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") } +func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0")) + h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1")) + require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash") +} + +func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0")) + h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1")) + require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash") +} + func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { svc := &GatewayService{} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 537b5a3b..e0f81a39 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -59,13 +59,17 @@ type Group struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 
apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) DefaultMappedModel string CreatedAt time.Time UpdatedAt time.Time - AccountGroups []AccountGroup - AccountCount int64 + AccountGroups []AccountGroup + AccountCount int64 + ActiveAccountCount int64 + RateLimitedAccountCount int64 } func (g *Group) IsActive() bool { diff --git a/backend/internal/service/group_capacity_service.go b/backend/internal/service/group_capacity_service.go new file mode 100644 index 00000000..459084dc --- /dev/null +++ b/backend/internal/service/group_capacity_service.go @@ -0,0 +1,131 @@ +package service + +import ( + "context" + "time" +) + +// GroupCapacitySummary holds aggregated capacity for a single group. +type GroupCapacitySummary struct { + GroupID int64 `json:"group_id"` + ConcurrencyUsed int `json:"concurrency_used"` + ConcurrencyMax int `json:"concurrency_max"` + SessionsUsed int `json:"sessions_used"` + SessionsMax int `json:"sessions_max"` + RPMUsed int `json:"rpm_used"` + RPMMax int `json:"rpm_max"` +} + +// GroupCapacityService aggregates per-group capacity from runtime data. +type GroupCapacityService struct { + accountRepo AccountRepository + groupRepo GroupRepository + concurrencyService *ConcurrencyService + sessionLimitCache SessionLimitCache + rpmCache RPMCache +} + +// NewGroupCapacityService creates a new GroupCapacityService. +func NewGroupCapacityService( + accountRepo AccountRepository, + groupRepo GroupRepository, + concurrencyService *ConcurrencyService, + sessionLimitCache SessionLimitCache, + rpmCache RPMCache, +) *GroupCapacityService { + return &GroupCapacityService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + concurrencyService: concurrencyService, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + } +} + +// GetAllGroupCapacity returns capacity summary for all active groups. 
+func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) { + groups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, err + } + + results := make([]GroupCapacitySummary, 0, len(groups)) + for i := range groups { + cap, err := s.getGroupCapacity(ctx, groups[i].ID) + if err != nil { + // Skip groups with errors, return partial results + continue + } + cap.GroupID = groups[i].ID + results = append(results, cap) + } + return results, nil +} + +func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) { + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID) + if err != nil { + return GroupCapacitySummary{}, err + } + if len(accounts) == 0 { + return GroupCapacitySummary{}, nil + } + + // Collect account IDs and config values + accountIDs := make([]int64, 0, len(accounts)) + sessionTimeouts := make(map[int64]time.Duration) + var concurrencyMax, sessionsMax, rpmMax int + + for i := range accounts { + acc := &accounts[i] + accountIDs = append(accountIDs, acc.ID) + concurrencyMax += acc.Concurrency + + if ms := acc.GetMaxSessions(); ms > 0 { + sessionsMax += ms + timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute + if timeout <= 0 { + timeout = 5 * time.Minute + } + sessionTimeouts[acc.ID] = timeout + } + + if rpm := acc.GetBaseRPM(); rpm > 0 { + rpmMax += rpm + } + } + + // Batch query runtime data from Redis + concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs) + + var sessionsMap map[int64]int + if sessionsMax > 0 && s.sessionLimitCache != nil { + sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts) + } + + var rpmMap map[int64]int + if rpmMax > 0 && s.rpmCache != nil { + rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs) + } + + // Aggregate + var concurrencyUsed, sessionsUsed, rpmUsed int + for _, id := range accountIDs { + 
concurrencyUsed += concurrencyMap[id] + if sessionsMap != nil { + sessionsUsed += sessionsMap[id] + } + if rpmMap != nil { + rpmUsed += rpmMap[id] + } + } + + return GroupCapacitySummary{ + ConcurrencyUsed: concurrencyUsed, + ConcurrencyMax: concurrencyMax, + SessionsUsed: sessionsUsed, + SessionsMax: sessionsMax, + RPMUsed: rpmUsed, + RPMMax: rpmMax, + }, nil +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 22a67eda..87174e03 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -27,7 +27,7 @@ type GroupRepository interface { ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ExistsByName(ctx context.Context, name string) (bool, error) - GetAccountCount(ctx context.Context, groupID int64) (int64, error) + GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) @@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, } // 获取账号数量 - accountCount, err := s.groupRepo.GetAccountCount(ctx, id) + accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id) if err != nil { return nil, fmt.Errorf("get account count: %w", err) } diff --git a/backend/internal/service/header_util.go b/backend/internal/service/header_util.go new file mode 100644 index 00000000..1091070d --- /dev/null +++ b/backend/internal/service/header_util.go @@ -0,0 +1,165 @@ +package service + +import ( + "net/http" + "strings" +) + +// headerWireCasing 定义每个白名单 header 在真实 Claude CLI 抓包中的准确大小写。 +// Go 的 HTTP server 解析请求时会将所有 header key 转为 Canonical 形式(如 x-app → X-App), +// 此 map 用于在转发时恢复到真实的 wire format。 +// +// 来源:对真实 Claude CLI (claude-cli/2.1.81) 到 api.anthropic.com 的 HTTPS 流量抓包。 +var 
headerWireCasing = map[string]string{ + // Title case + "accept": "Accept", + "user-agent": "User-Agent", + + // X-Stainless-* 保持 SDK 原始大小写 + "x-stainless-retry-count": "X-Stainless-Retry-Count", + "x-stainless-timeout": "X-Stainless-Timeout", + "x-stainless-lang": "X-Stainless-Lang", + "x-stainless-package-version": "X-Stainless-Package-Version", + "x-stainless-os": "X-Stainless-OS", + "x-stainless-arch": "X-Stainless-Arch", + "x-stainless-runtime": "X-Stainless-Runtime", + "x-stainless-runtime-version": "X-Stainless-Runtime-Version", + "x-stainless-helper-method": "x-stainless-helper-method", + + // Anthropic SDK 自身设置的 header,全小写 + "anthropic-dangerous-direct-browser-access": "anthropic-dangerous-direct-browser-access", + "anthropic-version": "anthropic-version", + "anthropic-beta": "anthropic-beta", + "x-app": "x-app", + "content-type": "content-type", + "accept-language": "accept-language", + "sec-fetch-mode": "sec-fetch-mode", + "accept-encoding": "accept-encoding", + "authorization": "authorization", + + // Claude Code 2.1.87+ 新增 header + "x-claude-code-session-id": "X-Claude-Code-Session-Id", + "x-client-request-id": "x-client-request-id", + "content-length": "content-length", +} + +// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。 +// 用于 debug log 按此顺序输出,便于与抓包结果直接对比。 +var headerWireOrder = []string{ + "Accept", + "X-Stainless-Retry-Count", + "X-Stainless-Timeout", + "X-Stainless-Lang", + "X-Stainless-Package-Version", + "X-Stainless-OS", + "X-Stainless-Arch", + "X-Stainless-Runtime", + "X-Stainless-Runtime-Version", + "anthropic-dangerous-direct-browser-access", + "anthropic-version", + "authorization", + "x-app", + "User-Agent", + "X-Claude-Code-Session-Id", + "content-type", + "anthropic-beta", + "x-client-request-id", + "accept-language", + "sec-fetch-mode", + "accept-encoding", + "content-length", + "x-stainless-helper-method", +} + +// headerWireOrderSet 用于快速判断某个 key 是否在 headerWireOrder 中(按 lowercase 匹配)。 +var headerWireOrderSet 
map[string]struct{} + +func init() { + headerWireOrderSet = make(map[string]struct{}, len(headerWireOrder)) + for _, k := range headerWireOrder { + headerWireOrderSet[strings.ToLower(k)] = struct{}{} + } +} + +// resolveWireCasing 将 Go canonical key(如 X-Stainless-Os)映射为真实 wire casing(如 X-Stainless-OS)。 +// 如果 map 中没有对应条目,返回原始 key 不变。 +func resolveWireCasing(key string) string { + if wk, ok := headerWireCasing[strings.ToLower(key)]; ok { + return wk + } + return key +} + +// setHeaderRaw sets a header bypassing Go's canonical-case normalization. +// The key is stored exactly as provided, preserving original casing. +// +// It first removes any existing value under the canonical key, the wire casing key, +// and the exact raw key, preventing duplicates from any source. +func setHeaderRaw(h http.Header, key, value string) { + h.Del(key) // remove canonical form (e.g. "Anthropic-Beta") + if wk := resolveWireCasing(key); wk != key { + delete(h, wk) // remove wire casing form if different + } + delete(h, key) // remove exact raw key if it differs from canonical + h[key] = []string{value} +} + +// addHeaderRaw appends a header value bypassing Go's canonical-case normalization. +func addHeaderRaw(h http.Header, key, value string) { + h[key] = append(h[key], value) +} + +// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch +// between Go canonical keys, wire casing keys, and raw keys: +// 1. exact key as provided +// 2. wire casing form (from headerWireCasing) +// 3. Go canonical form (via http.Header.Get) +func getHeaderRaw(h http.Header, key string) string { + // 1. exact key + if vals := h[key]; len(vals) > 0 { + return vals[0] + } + // 2. wire casing (e.g. looking up "Anthropic-Dangerous-Direct-Browser-Access" finds "anthropic-dangerous-direct-browser-access") + if wk := resolveWireCasing(key); wk != key { + if vals := h[wk]; len(vals) > 0 { + return vals[0] + } + } + // 3. 
canonical fallback + return h.Get(key) +} + +// sortHeadersByWireOrder 按照真实 Claude CLI 的 header 顺序返回排序后的 key 列表。 +// 在 headerWireOrder 中定义的 key 按其顺序排列,未定义的 key 追加到末尾。 +func sortHeadersByWireOrder(h http.Header) []string { + // 构建 lowercase -> actual map key 的映射 + present := make(map[string]string, len(h)) + for k := range h { + present[strings.ToLower(k)] = k + } + + result := make([]string, 0, len(h)) + seen := make(map[string]struct{}, len(h)) + + // 先按 wire order 输出 + for _, wk := range headerWireOrder { + lk := strings.ToLower(wk) + if actual, ok := present[lk]; ok { + if _, dup := seen[lk]; !dup { + result = append(result, actual) + seen[lk] = struct{}{} + } + } + } + + // 再追加不在 wire order 中的 header + for k := range h { + lk := strings.ToLower(k) + if _, ok := seen[lk]; !ok { + result = append(result, k) + seen[lk] = struct{}{} + } + } + + return result +} diff --git a/backend/internal/service/http_upstream_port.go b/backend/internal/service/http_upstream_port.go index 0e4cfbec..e8e76957 100644 --- a/backend/internal/service/http_upstream_port.go +++ b/backend/internal/service/http_upstream_port.go @@ -1,55 +1,24 @@ package service -import "net/http" +import ( + "net/http" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" +) // HTTPUpstream 上游 HTTP 请求接口 // 用于向上游 API(Claude、OpenAI、Gemini 等)发送请求 -// 这是一个通用接口,可用于任何基于 HTTP 的上游服务 -// -// 设计说明: -// - 支持可选代理配置 -// - 支持账户级连接池隔离 -// - 实现类负责连接池管理和复用 -// - 支持可选的 TLS 指纹伪装 type HTTPUpstream interface { - // Do 执行 HTTP 请求 - // - // 参数: - // - req: HTTP 请求对象,由调用方构建 - // - proxyURL: 代理服务器地址,空字符串表示直连 - // - accountID: 账户 ID,用于连接池隔离(隔离策略为 account 或 account_proxy 时生效) - // - accountConcurrency: 账户并发限制,用于动态调整连接池大小 - // - // 返回: - // - *http.Response: HTTP 响应,调用方必须关闭 Body - // - error: 请求错误(网络错误、超时等) - // - // 注意: - // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 - // - 响应体可能已被包装以跟踪请求生命周期 + // Do 执行 HTTP 请求(不启用 TLS 指纹) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) // DoWithTLS 
执行带 TLS 指纹伪装的 HTTP 请求 // - // 参数: - // - req: HTTP 请求对象,由调用方构建 - // - proxyURL: 代理服务器地址,空字符串表示直连 - // - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择 - // - accountConcurrency: 账户并发限制,用于动态调整连接池大小 - // - enableTLSFingerprint: 是否启用 TLS 指纹伪装 + // profile 参数: + // - nil: 不启用 TLS 指纹,行为与 Do 方法相同 + // - non-nil: 使用指定的 Profile 进行 TLS 指纹伪装 // - // 返回: - // - *http.Response: HTTP 响应,调用方必须关闭 Body - // - error: 请求错误(网络错误、超时等) - // - // TLS 指纹说明: - // - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹 - // - TLS 指纹模板根据 accountID % len(profiles) 自动选择 - // - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景 - // - 如果 enableTLSFingerprint=false,行为与 Do 方法相同 - // - // 注意: - // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 - // - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响 - DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) + // Profile 由调用方通过 TLSFingerprintProfileService 解析后传入, + // 支持按账号绑定的数据库 profile 或内置默认 profile。 + DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f6a94d15..3d706508 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" "log/slog" "net/http" @@ -15,14 +14,12 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // 预编译正则表达式(避免每次调用重新编译) var ( - // 匹配 user_id 格式: - // 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID) - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID) - userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`) // 匹配 User-Agent 版本号: xxx/x.y.z userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) 
@@ -177,6 +174,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { } // ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头) +// 使用 setHeaderRaw 保持原始大小写(如 X-Stainless-OS 而非 X-Stainless-Os) func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { if fp == nil { return @@ -184,92 +182,82 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { // 设置user-agent if fp.UserAgent != "" { - req.Header.Set("user-agent", fp.UserAgent) + setHeaderRaw(req.Header, "User-Agent", fp.UserAgent) } - // 设置x-stainless-*头 + // 设置x-stainless-*头(保持与 claude.DefaultHeaders 一致的大小写) if fp.StainlessLang != "" { - req.Header.Set("X-Stainless-Lang", fp.StainlessLang) + setHeaderRaw(req.Header, "X-Stainless-Lang", fp.StainlessLang) } if fp.StainlessPackageVersion != "" { - req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion) + setHeaderRaw(req.Header, "X-Stainless-Package-Version", fp.StainlessPackageVersion) } if fp.StainlessOS != "" { - req.Header.Set("X-Stainless-OS", fp.StainlessOS) + setHeaderRaw(req.Header, "X-Stainless-OS", fp.StainlessOS) } if fp.StainlessArch != "" { - req.Header.Set("X-Stainless-Arch", fp.StainlessArch) + setHeaderRaw(req.Header, "X-Stainless-Arch", fp.StainlessArch) } if fp.StainlessRuntime != "" { - req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime) + setHeaderRaw(req.Header, "X-Stainless-Runtime", fp.StainlessRuntime) } if fp.StainlessRuntimeVersion != "" { - req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion) + setHeaderRaw(req.Header, "X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion) } } // RewriteUserID 重写body中的metadata.user_id -// 输入格式:user_{clientId}_account__session_{sessionUUID} -// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} +// 支持旧拼接格式和新 JSON 格式的 user_id 解析, +// 根据 fingerprintUA 版本选择输出格式。 // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) 
RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { if len(body) == 0 || accountUUID == "" || cachedClientID == "" { return body, nil } - // 使用 RawMessage 保留其他字段的原始字节 - var reqMap map[string]json.RawMessage - if err := json.Unmarshal(body, &reqMap); err != nil { + metadata := gjson.GetBytes(body, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + return body, nil + } + if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") { return body, nil } - // 解析 metadata 字段 - metadataRaw, ok := reqMap["metadata"] - if !ok { + userIDResult := metadata.Get("user_id") + if !userIDResult.Exists() || userIDResult.Type != gjson.String { + return body, nil + } + userID := userIDResult.String() + if userID == "" { return body, nil } - var metadata map[string]any - if err := json.Unmarshal(metadataRaw, &metadata); err != nil { + // 解析 user_id(兼容旧拼接格式和新 JSON 格式) + parsed := ParseMetadataUserID(userID) + if parsed == nil { return body, nil } - userID, ok := metadata["user_id"].(string) - if !ok || userID == "" { - return body, nil - } - - // 匹配格式: - // 旧格式: user_{64位hex}_account__session_{uuid} - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} - matches := userIDRegex.FindStringSubmatch(userID) - if matches == nil { - return body, nil - } - - // matches[1] = account UUID (可能为空), matches[2] = session UUID - sessionTail := matches[2] // 原始session UUID + sessionTail := parsed.SessionID // 原始session UUID // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 seed := fmt.Sprintf("%d::%s", accountID, sessionTail) newSessionHash := generateUUIDFromSeed(seed) - // 构建新的user_id - // 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} - newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) + // 根据客户端版本选择输出格式 + 
version := ExtractCLIVersion(fingerprintUA) + newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version) + if newUserID == userID { + return body, nil + } - metadata["user_id"] = newUserID - - // 只重新序列化 metadata 字段 - newMetadataRaw, err := json.Marshal(metadata) + newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID) if err != nil { return body, nil } - reqMap["metadata"] = newMetadataRaw - - return json.Marshal(reqMap) + return newBody, nil } // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 @@ -278,9 +266,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { // 先执行常规的 RewriteUserID 逻辑 - newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) + newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA) if err != nil { return newBody, err } @@ -290,32 +278,26 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 使用 RawMessage 保留其他字段的原始字节 - var reqMap map[string]json.RawMessage - if err := json.Unmarshal(newBody, &reqMap); err != nil { + metadata := gjson.GetBytes(newBody, "metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + return newBody, nil + } + if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") { return newBody, nil } - // 解析 metadata 字段 - metadataRaw, ok := reqMap["metadata"] - if !ok { + userIDResult := metadata.Get("user_id") + if !userIDResult.Exists() || userIDResult.Type != gjson.String { + return newBody, nil + } + userID := 
userIDResult.String() + if userID == "" { return newBody, nil } - var metadata map[string]any - if err := json.Unmarshal(metadataRaw, &metadata); err != nil { - return newBody, nil - } - - userID, ok := metadata["user_id"].(string) - if !ok || userID == "" { - return newBody, nil - } - - // 查找 _session_ 的位置,替换其后的内容 - const sessionMarker = "_session_" - idx := strings.LastIndex(userID, sessionMarker) - if idx == -1 { + // 解析已重写的 user_id + uidParsed := ParseMetadataUserID(userID) + if uidParsed == nil { return newBody, nil } @@ -337,8 +319,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } - // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 - newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID + // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式) + version := ExtractCLIVersion(fingerprintUA) + newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version) slog.Debug("session_id_masking_applied", "account_id", account.ID, @@ -346,16 +329,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b "after", newUserID, ) - metadata["user_id"] = newUserID - - // 只重新序列化 metadata 字段 - newMetadataRaw, marshalErr := json.Marshal(metadata) - if marshalErr != nil { + if newUserID == userID { return newBody, nil } - reqMap["metadata"] = newMetadataRaw - return json.Marshal(reqMap) + maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID) + if setErr != nil { + return newBody, nil + } + return maskedBody, nil } // generateRandomUUID 生成随机 UUID v4 格式字符串 diff --git a/backend/internal/service/identity_service_order_test.go b/backend/internal/service/identity_service_order_test.go new file mode 100644 index 00000000..d1e12274 --- /dev/null +++ b/backend/internal/service/identity_service_order_test.go @@ -0,0 +1,82 @@ +package service + 
+import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type identityCacheStub struct { + maskedSessionID string +} + +func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) { + return nil, nil +} +func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error { + return nil +} +func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) { + return s.maskedSessionID, nil +} +func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error { + s.maskedSessionID = sessionID + return nil +} + +func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) { + cache := &identityCacheStub{} + svc := NewIdentityService(cache) + + originalUserID := FormatMetadataUserID( + "d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169", + "", + "7578cf37-aaca-46e4-a45c-71285d9dbb83", + "2.1.78", + ) + body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`) + + result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)") + require.NoError(t, err) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`) + require.NotContains(t, resultStr, originalUserID) + require.Contains(t, resultStr, `"metadata":{"user_id":"`) +} + +func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) { + cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"} + svc := NewIdentityService(cache) + + originalUserID := FormatMetadataUserID( + "d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169", + "", + 
"7578cf37-aaca-46e4-a45c-71285d9dbb83", + "2.1.78", + ) + body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`) + + account := &Account{ + ID: 123, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "session_id_masking_enabled": true, + }, + } + + result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)") + require.NoError(t, err) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`) + require.Contains(t, resultStr, cache.maskedSessionID) + require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`)) +} + +func strconvQuote(v string) string { + return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"` +} diff --git a/backend/internal/service/internal500_counter.go b/backend/internal/service/internal500_counter.go new file mode 100644 index 00000000..0f0bc50c --- /dev/null +++ b/backend/internal/service/internal500_counter.go @@ -0,0 +1,11 @@ +package service + +import "context" + +// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数 +type Internal500CounterCache interface { + // IncrementInternal500Count 原子递增计数并返回当前值 + IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) + // ResetInternal500Count 清零计数器(成功响应时调用) + ResetInternal500Count(ctx context.Context, accountID int64) error +} diff --git a/backend/internal/service/metadata_userid.go b/backend/internal/service/metadata_userid.go new file mode 100644 index 00000000..ee1ef64a --- /dev/null +++ b/backend/internal/service/metadata_userid.go @@ -0,0 +1,104 @@ +package service + +import ( + "encoding/json" + "regexp" + "strings" +) + +// NewMetadataFormatMinVersion is 
the minimum Claude Code version that uses +// JSON-formatted metadata.user_id instead of the legacy concatenated string. +const NewMetadataFormatMinVersion = "2.1.78" + +// ParsedUserID represents the components extracted from a metadata.user_id value. +type ParsedUserID struct { + DeviceID string // 64-char hex (or arbitrary client id) + AccountUUID string // may be empty + SessionID string // UUID + IsNewFormat bool // true if the original was JSON format +} + +// legacyUserIDRegex matches the legacy user_id format: +// +// user_{64hex}_account_{optional_uuid}_session_{uuid} +var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`) + +// jsonUserID is the JSON structure for the new metadata.user_id format. +type jsonUserID struct { + DeviceID string `json:"device_id"` + AccountUUID string `json:"account_uuid"` + SessionID string `json:"session_id"` +} + +// ParseMetadataUserID parses a metadata.user_id string in either format. +// Returns nil if the input cannot be parsed. +func ParseMetadataUserID(raw string) *ParsedUserID { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + // Try JSON format first (starts with '{') + if raw[0] == '{' { + var j jsonUserID + if err := json.Unmarshal([]byte(raw), &j); err != nil { + return nil + } + if j.DeviceID == "" || j.SessionID == "" { + return nil + } + return &ParsedUserID{ + DeviceID: j.DeviceID, + AccountUUID: j.AccountUUID, + SessionID: j.SessionID, + IsNewFormat: true, + } + } + + // Try legacy format + matches := legacyUserIDRegex.FindStringSubmatch(raw) + if matches == nil { + return nil + } + return &ParsedUserID{ + DeviceID: matches[1], + AccountUUID: matches[2], + SessionID: matches[3], + IsNewFormat: false, + } +} + +// FormatMetadataUserID builds a metadata.user_id string in the format +// appropriate for the given CLI version. Components are the rewritten values +// (not necessarily the originals). 
+func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string { + if IsNewMetadataFormatVersion(uaVersion) { + b, _ := json.Marshal(jsonUserID{ + DeviceID: deviceID, + AccountUUID: accountUUID, + SessionID: sessionID, + }) + return string(b) + } + // Legacy format + return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID +} + +// IsNewMetadataFormatVersion returns true if the given CLI version uses the +// new JSON metadata.user_id format (>= 2.1.78). +func IsNewMetadataFormatVersion(version string) bool { + if version == "" { + return false + } + return CompareVersions(version, NewMetadataFormatMinVersion) >= 0 +} + +// ExtractCLIVersion extracts the Claude Code version from a User-Agent string. +// Returns "" if the UA doesn't match the expected pattern. +func ExtractCLIVersion(ua string) string { + matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) + if len(matches) >= 2 { + return matches[1] + } + return "" +} diff --git a/backend/internal/service/metadata_userid_test.go b/backend/internal/service/metadata_userid_test.go new file mode 100644 index 00000000..40ad7087 --- /dev/null +++ b/backend/internal/service/metadata_userid_test.go @@ -0,0 +1,183 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ ParseMetadataUserID Tests ============ + +func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func 
TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_InvalidInputs(t *testing.T) { + tests := []struct { + name string + raw string + }{ + {"empty string", ""}, + {"whitespace only", " "}, + {"random 
text", "not-a-valid-user-id"}, + {"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"}, + {"invalid JSON", `{"device_id":}`}, + {"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`}, + {"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`}, + {"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw) + }) + } +} + +func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) { + // Legacy format should accept both upper and lower case hex + rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(rawUpper) + require.NotNil(t, parsed, "legacy format should accept uppercase hex") + require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID) +} + +// ============ FormatMetadataUserID Tests ============ + +func TestFormatMetadataUserID_LegacyVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result) +} + +func TestFormatMetadataUserID_NewVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78") + 
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result) +} + +func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result) +} + +func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) { + // Legacy format with empty account UUID → double underscore + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22") + require.Contains(t, result, "_account__session_") + + // New format with empty account UUID → empty string in JSON + result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78") + require.Contains(t, result, `"account_uuid":""`) +} + +// ============ IsNewMetadataFormatVersion Tests ============ + +func TestIsNewMetadataFormatVersion(t *testing.T) { + tests := []struct { + version string + want bool + }{ + {"", false}, + {"2.1.77", false}, + {"2.1.78", true}, + {"2.1.79", true}, + {"2.2.0", true}, + {"3.0.0", true}, + {"2.0.100", false}, + {"1.9.99", false}, + } + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version)) + }) + } +} + +// ============ Round-trip Tests ============ + +func TestParseFormat_RoundTrip_Legacy(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + 
require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_JSON(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + // Legacy round-trip with empty account UUID + formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + + // JSON round-trip with empty account UUID + formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78") + parsed = ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) +} diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 17b9128c..5dbba638 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 5. 
设置版本号 + 更新 DB if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - freshAccount.Credentials = newCredentials - if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil { + if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil { slog.Error("oauth_refresh_update_failed", "account_id", freshAccount.ID, "error", updateErr, diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index 6cf9371f..c3b38ddf 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -16,10 +16,11 @@ import ( // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. type refreshAPIAccountRepo struct { mockAccountRepoForGemini - account *Account // returned by GetByID - getByIDErr error - updateErr error - updateCalls int + account *Account // returned by GetByID + getByIDErr error + updateErr error + updateCalls int + updateCredentialsCalls int } func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { @@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { return r.updateErr } +func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + if r.account == nil || r.account.ID != id { + r.account = &Account{ID: id} + } + r.account.Credentials = cloneCredentials(credentials) + return nil +} + // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. 
type refreshAPIExecutorStub struct { needsRefresh bool @@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) { require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.Equal(t, 1, repo.updateCalls) // DB updated - require.Equal(t, 1, cache.releaseCalls) // lock released + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, executor.refreshCalls) } +func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) { + resetAt := time.Now().Add(45 * time.Minute) + account := &Account{ + ID: 11, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + } + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "safe-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotNil(t, repo.account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second) +} + func TestRefreshIfNeeded_LockHeld(t *testing.T) { account := &Account{ID: 2, Platform: PlatformAnthropic} repo := &refreshAPIAccountRepo{account: account} @@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) { require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "invalid_grant") - require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error + require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 1, cache.releaseCalls) // lock still released via defer } @@ -299,8 +339,8 @@ func 
TestMergeCredentials_NewOverridesOld(t *testing.T) { result := MergeCredentials(old, new) - require.Equal(t, "new-token", result["access_token"]) // overridden - require.Equal(t, "old-refresh", result["refresh_token"]) // preserved + require.Equal(t, "new-token", result["access_token"]) // overridden + require.Equal(t, "old-refresh", result["refresh_token"]) // preserved } // ========== BuildClaudeAccountCredentials tests ========== diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 789888cb..6c09e354 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -4,6 +4,7 @@ import ( "container/heap" "context" "errors" + "fmt" "hash/fnv" "math" "sort" @@ -330,6 +331,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel) + if account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { @@ -570,6 +576,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( return nil, 0, 0, 0, errors.New("no available OpenAI accounts") } + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if req.GroupID != nil && s.service.schedulerSnapshot != nil { + schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID) + } + filtered := make([]*Account, 0, len(accounts)) loadReq := make([]AccountWithConcurrency, 0, len(accounts)) for i := range accounts { @@ -582,6 +594,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if !account.IsSchedulable() || !account.IsOpenAI() { continue } + // require_privacy_set: 跳过 privacy 
未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() { + _ = s.service.accountRepo.SetError(ctx, account.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { continue } @@ -691,6 +709,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { continue } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 977c4ee8..088815ed 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa require.Equal(t, int64(32002), account.ID) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(10103) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 
0, RateLimitResetAt: &rateLimitedUntil} + dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{staleSticky, staleBackup}, + accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + cache: cache, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(33002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10104) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + dbSecondary := Account{ID: 34002, 
Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{stalePrimary, staleSecondary}, + accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(34002), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 29f2b672..21b4874e 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -7,6 +7,8 @@ import ( var codexModelMap = map[string]string{ "gpt-5.4": "gpt-5.4", + "gpt-5.4-mini": "gpt-5.4-mini", + "gpt-5.4-nano": "gpt-5.4-nano", "gpt-5.4-none": "gpt-5.4", "gpt-5.4-low": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4", @@ -83,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if v, ok := reqBody["model"].(string); ok { model = v } - normalizedModel := normalizeCodexModel(model) + normalizedModel := strings.TrimSpace(model) if normalizedModel != "" { if model != normalizedModel { reqBody["model"] = normalizedModel @@ -172,6 +174,11 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact result.PromptCacheKey = strings.TrimSpace(v) } + // 提取 input 中 role:"system" 消息至 instructions(OAuth 上游不支持 system role)。 + if 
extractSystemMessagesFromInput(reqBody) { + result.Modified = true + } + // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法 if applyInstructions(reqBody, isCodexCLI) { result.Modified = true @@ -220,6 +227,12 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { + return "gpt-5.4-mini" + } + if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") { + return "gpt-5.4-nano" + } if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return "gpt-5.4" } @@ -301,6 +314,73 @@ func getNormalizedCodexModel(modelID string) string { return "" } +// extractTextFromContent extracts plain text from a content value that is either +// a Go string or a []any of content-part maps with type:"text". +func extractTextFromContent(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var parts []string + for _, part := range v { + m, ok := part.(map[string]any) + if !ok { + continue + } + if t, _ := m["type"].(string); t == "text" { + if text, ok := m["text"].(string); ok { + parts = append(parts, text) + } + } + } + return strings.Join(parts, "") + default: + return "" + } +} + +// extractSystemMessagesFromInput scans the input array for items with role=="system", +// removes them, and merges their content into reqBody["instructions"]. +// If instructions is already non-empty, extracted content is prepended with "\n\n". +// Returns true if any system messages were extracted. 
+func extractSystemMessagesFromInput(reqBody map[string]any) bool { + input, ok := reqBody["input"].([]any) + if !ok || len(input) == 0 { + return false + } + + var systemTexts []string + remaining := make([]any, 0, len(input)) + + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + remaining = append(remaining, item) + continue + } + if role, _ := m["role"].(string); role != "system" { + remaining = append(remaining, item) + continue + } + if text := extractTextFromContent(m["content"]); text != "" { + systemTexts = append(systemTexts, text) + } + } + + if len(systemTexts) == 0 { + return false + } + + extracted := strings.Join(systemTexts, "\n\n") + if existing, ok := reqBody["instructions"].(string); ok && strings.TrimSpace(existing) != "" { + reqBody["instructions"] = extracted + "\n\n" + existing + } else { + reqBody["instructions"] = extracted + } + reqBody["input"] = remaining + return true +} + // applyInstructions 处理 instructions 字段:仅在 instructions 为空时填充默认值。 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if !isInstructionsEmpty(reqBody) { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index ae6f8555..889ac615 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -238,10 +238,15 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.4-high": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4", "gpt 5.4": "gpt-5.4", + "gpt-5.4-mini": "gpt-5.4-mini", + "gpt 5.4 mini": "gpt-5.4-mini", + "gpt-5.4-nano": "gpt-5.4-nano", + "gpt 5.4 nano": "gpt-5.4-nano", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt 5.3 codex spark": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex", @@ 
-252,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) +} + +func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) { + reqBody := map[string]any{ + "model": " gpt-5.3-codex-spark ", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { // Codex CLI 场景:已有 instructions 时不修改 @@ -344,6 +377,135 @@ func TestApplyCodexOAuthTransform_StringInputWithToolsField(t *testing.T) { require.Len(t, input, 1) } +func TestExtractSystemMessagesFromInput(t *testing.T) { + t.Run("no system messages", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := extractSystemMessagesFromInput(reqBody) + require.False(t, result) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + _, hasInstructions := reqBody["instructions"] + require.False(t, hasInstructions) + }) + + t.Run("string content system message", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{"role": "system", "content": "You are an assistant."}, + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := extractSystemMessagesFromInput(reqBody) + require.True(t, result) + input, ok := 
reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + msg, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "You are an assistant.", reqBody["instructions"]) + }) + + t.Run("array content system message", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{ + "role": "system", + "content": []any{ + map[string]any{"type": "text", "text": "Be helpful."}, + }, + }, + }, + } + result := extractSystemMessagesFromInput(reqBody) + require.True(t, result) + require.Equal(t, "Be helpful.", reqBody["instructions"]) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 0) + }) + + t.Run("multiple system messages concatenated", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{"role": "system", "content": "First."}, + map[string]any{"role": "system", "content": "Second."}, + map[string]any{"role": "user", "content": "hi"}, + }, + } + result := extractSystemMessagesFromInput(reqBody) + require.True(t, result) + require.Equal(t, "First.\n\nSecond.", reqBody["instructions"]) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + }) + + t.Run("mixed system and non-system preserves non-system", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "system", "content": "Sys prompt."}, + map[string]any{"role": "assistant", "content": "Hi there"}, + }, + } + result := extractSystemMessagesFromInput(reqBody) + require.True(t, result) + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "user", first["role"]) + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "assistant", second["role"]) + }) + + t.Run("existing instructions 
prepended", func(t *testing.T) { + reqBody := map[string]any{ + "input": []any{ + map[string]any{"role": "system", "content": "Extracted."}, + map[string]any{"role": "user", "content": "hi"}, + }, + "instructions": "Existing instructions.", + } + result := extractSystemMessagesFromInput(reqBody) + require.True(t, result) + require.Equal(t, "Extracted.\n\nExisting instructions.", reqBody["instructions"]) + }) +} + +func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.1", + "input": []any{ + map[string]any{"role": "system", "content": "You are a coding assistant."}, + map[string]any{"role": "user", "content": "Write a function."}, + }, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.True(t, result.Modified) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + msg, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "user", msg["role"]) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Equal(t, "You are a coding assistant.", instructions) +} + func TestIsInstructionsEmpty(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/openai_compat_model.go b/backend/internal/service/openai_compat_model.go new file mode 100644 index 00000000..5f140d01 --- /dev/null +++ b/backend/internal/service/openai_compat_model.go @@ -0,0 +1,103 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +func NormalizeOpenAICompatRequestedModel(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "" + } + + normalized, _, ok := splitOpenAICompatReasoningModel(trimmed) + if !ok || normalized == "" { + return trimmed + } + return normalized +} + +func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) { + if req == nil { + return + } + + originalModel := 
strings.TrimSpace(req.Model) + if originalModel == "" { + return + } + + normalizedModel, derivedEffort, hasReasoningSuffix := splitOpenAICompatReasoningModel(originalModel) + if hasReasoningSuffix && normalizedModel != "" { + req.Model = normalizedModel + } + + if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" { + return + } + + claudeEffort := openAIReasoningEffortToClaudeOutputEffort(derivedEffort) + if claudeEffort == "" { + return + } + + if req.OutputConfig == nil { + req.OutputConfig = &apicompat.AnthropicOutputConfig{} + } + req.OutputConfig.Effort = claudeEffort +} + +func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "", "", false + } + + modelID := trimmed + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + modelID = strings.TrimSpace(modelID) + if !strings.HasPrefix(strings.ToLower(modelID), "gpt-") { + return trimmed, "", false + } + + parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { + switch r { + case '-', '_', ' ': + return true + default: + return false + } + }) + if len(parts) == 0 { + return trimmed, "", false + } + + last := strings.NewReplacer("-", "", "_", "", " ", "").Replace(parts[len(parts)-1]) + switch last { + case "none", "minimal": + case "low", "medium", "high": + reasoningEffort = last + case "xhigh", "extrahigh": + reasoningEffort = "xhigh" + default: + return trimmed, "", false + } + + return normalizeCodexModel(modelID), reasoningEffort, true +} + +func openAIReasoningEffortToClaudeOutputEffort(effort string) string { + switch strings.TrimSpace(effort) { + case "low", "medium", "high": + return effort + case "xhigh": + return "max" + default: + return "" + } +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go new file 
mode 100644 index 00000000..32c646d4 --- /dev/null +++ b/backend/internal/service/openai_compat_model_test.go @@ -0,0 +1,129 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"}, + {name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"}, + {name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"}, + {name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input)) + }) + } +} + +func TestApplyOpenAICompatModelNormalization(t *testing.T) { + t.Parallel() + + t.Run("derives xhigh from model suffix when output config missing", func(t *testing.T) { + req := &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"} + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "gpt-5.4", req.Model) + require.NotNil(t, req.OutputConfig) + require.Equal(t, "max", req.OutputConfig.Effort) + }) + + t.Run("explicit output config wins over model suffix", func(t *testing.T) { + req := &apicompat.AnthropicRequest{ + Model: "gpt-5.4-xhigh", + OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"}, + } + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "gpt-5.4", req.Model) + require.NotNil(t, req.OutputConfig) + require.Equal(t, "low", req.OutputConfig.Effort) + }) + + t.Run("non openai model is untouched", func(t *testing.T) { + req := &apicompat.AnthropicRequest{Model: 
"claude-opus-4-6"} + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "claude-opus-4-6", req.Model) + require.Nil(t, req.OutputConfig) + }) +} + +func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-5.4-xhigh", result.Model) + require.Equal(t, "gpt-5.4", result.UpstreamModel) + require.Equal(t, "gpt-5.4", result.BillingModel) + require.NotNil(t, 
result.ReasoningEffort) + require.Equal(t, "xhigh", *result.ReasoningEffort) + + require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String()) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String()) + require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String()) + t.Logf("upstream body: %s", string(upstream.lastBody)) + t.Logf("response body: %s", rec.Body.String()) +} diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go new file mode 100644 index 00000000..46381838 --- /dev/null +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -0,0 +1,81 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +const compatPromptCacheKeyPrefix = "compat_cc_" + +func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { + switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": + return true + default: + return false + } +} + +func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedModel string) string { + if req == nil { + return "" + } + + normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel)) + if normalizedModel == "" { + normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model)) + } + if normalizedModel == "" { + normalizedModel = strings.TrimSpace(req.Model) + } + + seedParts := []string{"model=" + normalizedModel} + if req.ReasoningEffort != "" { + seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.ReasoningEffort)) + } + if len(req.ToolChoice) > 0 { + seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice)) + } + if 
len(req.Tools) > 0 { + if raw, err := json.Marshal(req.Tools); err == nil { + seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw)) + } + } + if len(req.Functions) > 0 { + if raw, err := json.Marshal(req.Functions); err == nil { + seedParts = append(seedParts, "functions="+normalizeCompatSeedJSON(raw)) + } + } + + firstUserCaptured := false + for _, msg := range req.Messages { + switch strings.TrimSpace(msg.Role) { + case "system": + seedParts = append(seedParts, "system="+normalizeCompatSeedJSON(msg.Content)) + case "user": + if !firstUserCaptured { + seedParts = append(seedParts, "first_user="+normalizeCompatSeedJSON(msg.Content)) + firstUserCaptured = true + } + } + } + + return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|")) +} + +func normalizeCompatSeedJSON(v json.RawMessage) string { + if len(v) == 0 { + return "" + } + var tmp any + if err := json.Unmarshal(v, &tmp); err != nil { + return string(v) + } + out, err := json.Marshal(tmp) + if err != nil { + return string(v) + } + return string(out) +} diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go new file mode 100644 index 00000000..6ca3e85c --- /dev/null +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -0,0 +1,79 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/stretchr/testify/require" +) + +func mustRawJSON(t *testing.T, s string) json.RawMessage { + t.Helper() + return json.RawMessage(s) +} + +func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) + 
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) +} + +func TestDeriveCompatPromptCacheKey_StableAcrossLaterTurns(t *testing.T) { + base := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)}, + {Role: "user", Content: mustRawJSON(t, `"Hello"`)}, + }, + } + extended := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "system", Content: mustRawJSON(t, `"You are helpful."`)}, + {Role: "user", Content: mustRawJSON(t, `"Hello"`)}, + {Role: "assistant", Content: mustRawJSON(t, `"Hi there!"`)}, + {Role: "user", Content: mustRawJSON(t, `"How are you?"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(base, "gpt-5.4") + k2 := deriveCompatPromptCacheKey(extended, "gpt-5.4") + require.Equal(t, k1, k2, "cache key should be stable across later turns") + require.NotEmpty(t, k1) +} + +func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { + req1 := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + }, + } + req2 := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.4", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question B"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(req1, "gpt-5.4") + k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") + require.NotEqual(t, k1, k2, "different first user messages should yield different keys") +} + +func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) { + req := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.3-codex-spark", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark") + k2 := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ") + require.NotEmpty(t, k1) + 
require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 9529f6be..1d5bf0d0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -43,23 +43,40 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( clientStream := chatReq.Stream includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage - // 2. Convert to Responses and forward + // 2. Resolve model mapping early so compat prompt_cache_key injection can + // derive a stable seed from the final upstream model family. + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) + + promptCacheKey = strings.TrimSpace(promptCacheKey) + compatPromptCacheInjected := false + if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel) + compatPromptCacheInjected = promptCacheKey != "" + } + + // 3. Convert to Responses and forward // ChatCompletionsToResponses always sets Stream=true (upstream always streams). responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) if err != nil { return nil, fmt.Errorf("convert chat completions to responses: %w", err) } + responsesReq.Model = upstreamModel - // 3. 
Model mapping - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) - responsesReq.Model = mappedModel - - logger.L().Debug("openai chat_completions: model mapping applied", + logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", clientStream), - ) + } + if compatPromptCacheInjected { + logFields = append(logFields, + zap.Bool("compat_prompt_cache_key_injected", true), + zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)), + ) + } + logger.L().Debug("openai chat_completions: model mapping applied", logFields...) // 4. Marshal Responses request body, then apply OAuth codex transform responsesBody, err := json.Marshal(responsesReq) @@ -73,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -165,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) } else { - result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and 
ReasoningEffort to result for billing @@ -209,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -277,12 +298,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( c.JSON(http.StatusOK, chatResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -292,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, includeUsage bool, startTime time.Time, ) (*OpenAIForwardResult, error) { @@ -324,13 +347,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 58714571..8c389556 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -40,6 +40,8 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( 
return nil, fmt.Errorf("parse anthropic request: %w", err) } originalModel := anthropicReq.Model + applyOpenAICompatModelNormalization(&anthropicReq) + normalizedModel := anthropicReq.Model clientStream := anthropicReq.Stream // client's original stream preference // 2. Convert Anthropic → Responses @@ -59,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) - responsesReq.Model = mappedModel + billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) + responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("normalized_model", normalizedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), ) @@ -81,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -181,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } else { // Client wants JSON: buffer the streaming response and assemble a JSON reply. 
- result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -229,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -299,12 +308,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( c.JSON(http.StatusOK, anthropicResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -317,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -347,13 +358,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( // resultWithUsage builds the final result snapshot. 
resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index ada7d805..7a636afa 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } -func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} @@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te RequestID: "resp_billing_model_override", BillingModel: "gpt-5.1-codex", Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", ServiceTier: &serviceTier, ReasoningEffort: &reasoning, Usage: OpenAIUsage{ @@ -877,7 +878,10 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) - require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.1-codex", 
*usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ReasoningEffort) @@ -891,6 +895,42 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te require.Equal(t, 1, userRepo.deductCalls) } +func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + // Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex"). + // This ensures pricing is always based on the model the user requested. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_upstream_model_billing_fallback", + Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost) + require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount) +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} diff --git 
a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c8876edb..e85f0705 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -216,6 +216,9 @@ type OpenAIForwardResult struct { // This is set by the Anthropic Messages conversion path where // the mapped upstream model differs from the client-facing model. BillingModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Empty when no mapping was applied (requested model was used as-is). + UpstreamModel string // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // Nil means the request did not specify a recognized tier. ServiceTier *string @@ -1198,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1226,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel) + if fresh == nil { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1350,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return 
&AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -1557,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. 
return fresh } +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil + } + if s.schedulerSnapshot == nil || s.accountRepo == nil { + return account + } + + latest, err := s.accountRepo.GetByID(ctx, account.ID) + if err != nil || latest == nil { + return nil + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now()) + if !latest.IsSchedulable() || !latest.IsOpenAI() { + return nil + } + if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + return nil + } + return latest +} + func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { var ( account *Account @@ -1775,29 +1814,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // 对所有请求执行模型映射(包含 Codex CLI)。 - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) - reqBody["model"] = mappedModel + billingModel := account.GetMappedModel(reqModel) + if billingModel != reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) + reqBody["model"] = billingModel bodyModified = true - markPatchSet("model", mappedModel) + markPatchSet("model", billingModel) } + upstreamModel := billingModel // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { - normalizedModel := normalizeCodexModel(model) - if normalizedModel != "" && normalizedModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, normalizedModel, account.Name, account.Type, 
isCodexCLI) - reqBody["model"] = normalizedModel - mappedModel = normalizedModel + upstreamModel = resolveOpenAIUpstreamModel(model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel bodyModified = true - markPatchSet("model", normalizedModel) + markPatchSet("model", upstreamModel) } // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 // 确保高版本模型向低版本模型映射不报错 - if !SupportsVerbosity(normalizedModel) { + if !SupportsVerbosity(upstreamModel) { if text, ok := reqBody["text"].(map[string]any); ok { delete(text, "verbosity") } @@ -1821,7 +1860,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } if codexResult.NormalizedModel != "" { - mappedModel = codexResult.NormalizedModel + upstreamModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey @@ -1938,7 +1977,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", account.ID, account.Type, - mappedModel, + upstreamModel, reqStream, hasPreviousResponseID, ) @@ -2027,7 +2066,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco isCodexCLI, reqStream, originalModel, - mappedModel, + upstreamModel, startTime, attempt, wsLastFailureReason, @@ -2128,6 +2167,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) + wsResult.UpstreamModel = upstreamModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2232,14 +2272,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var usage *OpenAIUsage var firstTokenMs *int if reqStream 
{ - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } @@ -2263,6 +2303,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, + UpstreamModel: upstreamModel, ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, @@ -2593,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + if s.rateLimitService != nil { + // Passthrough mode preserves the raw upstream error response, but runtime + // account state still needs to be updated so sticky routing can stop + // reusing a freshly rate-limited account. 
+ _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4105,10 +4152,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } - billingModel := result.Model - if result.BillingModel != "" { - billingModel = result.BillingModel - } + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) @@ -4134,7 +4178,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: requestID, - Model: billingModel, + Model: result.Model, + RequestedModel: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -4700,11 +4746,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { return "" } } - -func optionalTrimmedStringPtr(raw string) *string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return nil - } - return &trimmed -} diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 9bf3fba3..4f8c094b 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,8 +1,10 @@ package service -// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible -// forwarding. Group-level default mapping only applies when the account itself -// did not match any explicit model_mapping rule. 
+import "strings" + +// resolveOpenAIForwardModel resolves the account/group mapping result for +// OpenAI-compatible forwarding. Group-level default mapping only applies when +// the account itself did not match any explicit model_mapping rule. func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { if defaultMappedModel != "" { @@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } return mappedModel } + +func resolveOpenAIUpstreamModel(model string) string { + if isBareGPT53CodexSparkModel(model) { + return "gpt-5.3-codex-spark" + } + return normalizeCodexModel(strings.TrimSpace(model)) +} + +func isBareGPT53CodexSparkModel(model string) bool { + modelID := strings.TrimSpace(model) + if modelID == "" { + return false + } + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + normalized := strings.ToLower(strings.TrimSpace(modelID)) + return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark" +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 7af3ecae..42f58b37 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -68,3 +68,36 @@ func TestResolveOpenAIForwardModel(t *testing.T) { }) } } + +func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { + account := &Account{ + Credentials: map[string]any{}, + } + + withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) + if withoutDefault != "gpt-5.1" { + t.Fatalf("resolveOpenAIUpstreamModel(...) 
= %q, want %q", withoutDefault, "gpt-5.1") + } + + withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + if withDefault != "gpt-5.4" { + t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") + } +} + +func TestResolveOpenAIUpstreamModel(t *testing.T) { + cases := map[string]string{ + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + " openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + } + + for input, expected := range cases { + if got := resolveOpenAIUpstreamModel(input); got != expected { + t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) + } + } +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index f51a7491..97fa218d 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -43,7 +44,7 @@ func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID return u.resp, nil } -func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return u.Do(req, proxyURL, accountID, accountConcurrency) } @@ -536,6 +537,55 @@ func 
TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, arr[len(arr)-1].Passthrough) } +func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-rate-limit"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))), + } + upstream := &httpUpstreamRecorder{resp: resp} + repo := &openAIWSRateLimitSignalRepo{} + rateSvc := &RateLimitService{accountRepo: repo} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + rateLimitService: rateSvc, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Contains(t, rec.Body.String(), "usage_limit_reached") + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), 
repo.rateLimitCalls[0], 2*time.Second) +} + func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index bd82e107..0f004b01 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -29,9 +29,10 @@ type soraSessionChunk struct { // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { - sessionStore *openai.SessionStore - proxyRepo ProxyRepository - oauthClient OpenAIOAuthClient + sessionStore *openai.SessionStore + proxyRepo ProxyRepository + oauthClient OpenAIOAuthClient + privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-api(ImpersonateChrome) } // NewOpenAIOAuthService creates a new OpenAI OAuth service @@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli } } +// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂, +// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。 +func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) { + s.privacyClientFactory = factory +} + // OpenAIAuthURLResult contains the authorization URL and session info type OpenAIAuthURLResult struct { AuthURL string `json:"auth_url"` @@ -131,6 +138,7 @@ type OpenAITokenInfo struct { ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"` PlanType string `json:"plan_type,omitempty"` + PrivacyMode string `json:"privacy_mode,omitempty"` } // ExchangeCode exchanges authorization code for tokens @@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre tokenInfo.PlanType = userInfo.PlanType } + // id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全 + if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && 
s.privacyClientFactory != nil { + // 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号 + orgID := tokenInfo.OrganizationID + if orgID == "" { + if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil { + orgID = atClaims.OpenAIAuth.POID + } + } + if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil { + if tokenInfo.PlanType == "" && info.PlanType != "" { + tokenInfo.PlanType = info.PlanType + } + if tokenInfo.Email == "" && info.Email != "" { + tokenInfo.Email = info.Email + } + } + } + + // 尝试设置隐私(关闭训练数据共享),best-effort + if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { + tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) + } + return tokenInfo, nil } @@ -470,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { + accessToken := account.GetCredential("access_token") + if accessToken != "" { + tokenInfo := &OpenAITokenInfo{ + AccessToken: accessToken, + RefreshToken: "", + IDToken: account.GetCredential("id_token"), + ClientID: account.GetCredential("client_id"), + Email: account.GetCredential("email"), + ChatGPTAccountID: account.GetCredential("chatgpt_account_id"), + ChatGPTUserID: account.GetCredential("chatgpt_user_id"), + OrganizationID: account.GetCredential("organization_id"), + PlanType: account.GetCredential("plan_type"), + } + if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { + tokenInfo.ExpiresAt = expiresAt.Unix() + tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds()) + } + return tokenInfo, nil + } return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } diff --git a/backend/internal/service/openai_oauth_service_refresh_test.go 
b/backend/internal/service/openai_oauth_service_refresh_test.go new file mode 100644 index 00000000..a31eb8cb --- /dev/null +++ b/backend/internal/service/openai_oauth_service_refresh_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientRefreshStub struct { + refreshCalls int32 +} + +func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) { + client := &openaiOAuthClientRefreshStub{} + svc := NewOpenAIOAuthService(nil, client) + + expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339) + account := &Account{ + ID: 77, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "existing-access-token", + "expires_at": expiresAt, + "client_id": "client-id-1", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "existing-access-token", info.AccessToken) + require.Equal(t, "client-id-1", info.ClientID) + require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without 
calling refresh") +} diff --git a/backend/internal/service/openai_privacy_retry_test.go b/backend/internal/service/openai_privacy_retry_test.go new file mode 100644 index 00000000..24534ea9 --- /dev/null +++ b/backend/internal/service/openai_privacy_retry_test.go @@ -0,0 +1,89 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/imroc/req/v3" + "github.com/stretchr/testify/require" +) + +func TestAdminService_EnsureOpenAIPrivacy_RetriesNonSuccessModes(t *testing.T) { + t.Parallel() + + for _, mode := range []string{PrivacyModeFailed, PrivacyModeCFBlocked} { + t.Run(mode, func(t *testing.T) { + t.Parallel() + + privacyCalls := 0 + svc := &adminServiceImpl{ + accountRepo: &mockAccountRepoForGemini{}, + privacyClientFactory: func(proxyURL string) (*req.Client, error) { + privacyCalls++ + return nil, errors.New("factory failed") + }, + } + + account := &Account{ + ID: 101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-1", + }, + Extra: map[string]any{ + "privacy_mode": mode, + }, + } + + got := svc.EnsureOpenAIPrivacy(context.Background(), account) + + require.Equal(t, PrivacyModeFailed, got) + require.Equal(t, 1, privacyCalls) + }) + } +} + +func TestTokenRefreshService_ensureOpenAIPrivacy_RetriesNonSuccessModes(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + + for _, mode := range []string{PrivacyModeFailed, PrivacyModeCFBlocked} { + t.Run(mode, func(t *testing.T) { + t.Parallel() + + service := NewTokenRefreshService(&tokenRefreshAccountRepo{}, nil, nil, nil, nil, nil, nil, cfg, nil) + privacyCalls := 0 + service.SetPrivacyDeps(func(proxyURL string) (*req.Client, error) { + privacyCalls++ + return nil, errors.New("factory failed") + }, nil) + + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, 
+ Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-2", + }, + Extra: map[string]any{ + "privacy_mode": mode, + }, + } + + service.ensureOpenAIPrivacy(context.Background(), account) + + require.Equal(t, 1, privacyCalls) + }) + } +} diff --git a/backend/internal/service/openai_privacy_service.go b/backend/internal/service/openai_privacy_service.go index 90cd522d..6bc71ab9 100644 --- a/backend/internal/service/openai_privacy_service.go +++ b/backend/internal/service/openai_privacy_service.go @@ -22,6 +22,19 @@ const ( PrivacyModeCFBlocked = "training_set_cf_blocked" ) +func shouldSkipOpenAIPrivacyEnsure(extra map[string]any) bool { + if extra == nil { + return false + } + raw, ok := extra["privacy_mode"] + if !ok { + return false + } + mode, _ := raw.(string) + mode = strings.TrimSpace(mode) + return mode != PrivacyModeFailed && mode != PrivacyModeCFBlocked +} + // disableOpenAITraining calls ChatGPT settings API to turn off "Improve the model for everyone". // Returns privacy_mode value: "training_off" on success, "cf_blocked" / "failed" on failure. func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL string) string { @@ -69,6 +82,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto return PrivacyModeTrainingOff } +// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息 +type ChatGPTAccountInfo struct { + PlanType string + Email string +} + +const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" + +// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.). +// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT). +// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team). +// Returns nil on any failure (best-effort, non-blocking). 
+func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo { + if accessToken == "" || clientFactory == nil { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + client, err := clientFactory(proxyURL) + if err != nil { + slog.Debug("chatgpt_account_check_client_error", "error", err.Error()) + return nil + } + + var result map[string]any + resp, err := client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Origin", "https://chatgpt.com"). + SetHeader("Referer", "https://chatgpt.com/"). + SetHeader("Accept", "application/json"). + SetSuccessResult(&result). + Get(chatGPTAccountsCheckURL) + + if err != nil { + slog.Debug("chatgpt_account_check_request_error", "error", err.Error()) + return nil + } + + if !resp.IsSuccessState() { + slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200)) + return nil + } + + info := &ChatGPTAccountInfo{} + + accounts, ok := result["accounts"].(map[string]any) + if !ok { + slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300)) + return nil + } + + // 优先匹配 orgID 对应的账号(access_token JWT 中的 poid) + if orgID != "" { + if matched := extractPlanFromAccount(accounts, orgID); matched != "" { + info.PlanType = matched + } + } + + // 未匹配到时,遍历所有账号:优先 is_default,次选非 free + if info.PlanType == "" { + var defaultPlan, paidPlan, anyPlan string + for _, acctRaw := range accounts { + acct, ok := acctRaw.(map[string]any) + if !ok { + continue + } + planType := extractPlanType(acct) + if planType == "" { + continue + } + if anyPlan == "" { + anyPlan = planType + } + if account, ok := acct["account"].(map[string]any); ok { + if isDefault, _ := account["is_default"].(bool); isDefault { + defaultPlan = planType + } + } + if !strings.EqualFold(planType, "free") && paidPlan == "" { + paidPlan = planType + } + } 
+ // 优先级:default > 非 free > 任意 + switch { + case defaultPlan != "": + info.PlanType = defaultPlan + case paidPlan != "": + info.PlanType = paidPlan + default: + info.PlanType = anyPlan + } + } + + if info.PlanType == "" { + slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300)) + return nil + } + + slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID) + return info +} + +// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type +func extractPlanFromAccount(accounts map[string]any, accountKey string) string { + acctRaw, ok := accounts[accountKey] + if !ok { + return "" + } + acct, ok := acctRaw.(map[string]any) + if !ok { + return "" + } + return extractPlanType(acct) +} + +// extractPlanType 从单个 account 对象中提取 plan_type +func extractPlanType(acct map[string]any) string { + if account, ok := acct["account"].(map[string]any); ok { + if planType, ok := account["plan_type"].(string); ok && planType != "" { + return planType + } + } + if entitlement, ok := acct["entitlement"].(map[string]any); ok { + if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" { + return subPlan + } + } + return "" +} + func truncate(s string, n int) string { if len(s) <= n { return s diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 9a8803d3..a5b97ca9 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( require.Zero(t, boundAccountID) } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(24) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleAccount := &Account{ + ID: 13, + Platform: PlatformOpenAI, + 
Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + dbAccount := Account{ + ID: 13, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + snapshotCache := &openAISnapshotCacheStub{ + accountsByID: map[int64]*Account{dbAccount.ID: staleAccount}, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 1d3d8fdf..1ebe5542 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( RequestID: responseID, Usage: *usage, Model: originalModel, + UpstreamModel: 
mappedModel, ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, @@ -2514,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - mappedModel := account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } - if mappedModel != originalModel { - next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + if upstreamModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } @@ -2775,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } + mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) @@ -2945,6 +2940,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( RequestID: responseID, Usage: usage, Model: originalModel, + UpstreamModel: mappedModel, ServiceTier: extractOpenAIServiceTierFromBody(payload), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, @@ -3844,6 +3840,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } + account = 
s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 76c66f2f..8c5c9368 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -57,7 +58,7 @@ func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, ac return u.responses[len(u.responses)-1], nil } -func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { +func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { return u.Do(req, proxyURL, accountID, accountConcurrency) } diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index f5c79923..ffe79152 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re return nil } -func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) 
([]Account, *pagination.PaginationResult, error) { +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { _ = platform _ = accountType _ = status _ = search _ = groupID + _ = privacyMode return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil } @@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount( } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Len(t, accounts, 1) diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 88883180..11c5d5ce 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -88,6 +88,7 @@ func (s *OpsAlertEvaluatorService) Start() { if s.stopCh == nil { s.stopCh = make(chan struct{}) } + s.wg.Add(1) go s.run() }) } @@ -105,7 +106,6 @@ func (s *OpsAlertEvaluatorService) Stop() { } func (s *OpsAlertEvaluatorService) run() { - s.wg.Add(1) defer s.wg.Done() // Start immediately to produce early feedback in ops dashboard. 
@@ -848,7 +848,9 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc return nil, false } return func() { - _, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result() + releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer releaseCancel() + _, _ = opsAlertEvaluatorReleaseScript.Run(releaseCtx, s.redisClient, []string{key}, s.instanceID).Result() }, true } diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index a571dd4d..ad303d92 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "", 0) + }, platformFilter, "", "", "", 0, "") if err != nil { return nil, err } diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 2ed06d90..5fefb74f 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -62,6 +62,12 @@ type OpsErrorLog struct { ClientIP *string `json:"client_ip"` RequestPath string `json:"request_path"` Stream bool `json:"stream"` + + InboundEndpoint string `json:"inbound_endpoint"` + UpstreamEndpoint string `json:"upstream_endpoint"` + RequestedModel string `json:"requested_model"` + UpstreamModel string `json:"upstream_model"` + RequestType *int16 `json:"request_type"` } type OpsErrorLogDetail struct { diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 0ce9d425..04bf91c8 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct { Model string RequestPath string Stream bool 
+ // InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions. + InboundEndpoint string + // UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses. + UpstreamEndpoint string + // RequestedModel is the client-requested model name before mapping. + RequestedModel string + // UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping. + UpstreamModel string + // RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2. + // Matches service.RequestType enum semantics from usage_log.go. + RequestType *int16 UserAgent string ErrorPhase string diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 21e09c43..05d444e1 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -53,6 +53,13 @@ func SetOpsLatencyMs(c *gin.Context, key string, value int64) { c.Set(key, value) } +// SetOpsUpstreamError is the exported wrapper for setOpsUpstreamError, used by +// handler-layer code (e.g. failover-exhausted paths) that needs to record the +// original upstream status code before mapping it to a client-facing code. +func SetOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { + setOpsUpstreamError(c, upstreamStatusCode, upstreamMessage, upstreamDetail) +} + func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { if c == nil { return @@ -86,6 +93,10 @@ type OpsUpstreamErrorEvent struct { UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"` + // UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped). + // Helps debug 404/routing errors by showing which endpoint was targeted. 
+ UpstreamURL string `json:"upstream_url,omitempty"` + // Best-effort upstream request capture (sanitized+trimmed). // Required for retrying a specific upstream attempt. UpstreamRequestBody string `json:"upstream_request_body,omitempty"` @@ -112,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.Kind = strings.TrimSpace(ev.Kind) + ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL) ev.Message = strings.TrimSpace(ev.Message) ev.Detail = strings.TrimSpace(ev.Detail) if ev.Message != "" { @@ -198,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) { } return out, nil } + +// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment +// to avoid leaking sensitive query parameters (e.g. OAuth tokens). +func safeUpstreamURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + if idx := strings.IndexByte(rawURL, '?'); idx >= 0 { + rawURL = rawURL[:idx] + } + if idx := strings.IndexByte(rawURL, '#'); idx >= 0 { + rawURL = rawURL[:idx] + } + return rawURL +} diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go index 50ceaa0e..fa6d1085 100644 --- a/backend/internal/service/ops_upstream_context_test.go +++ b/backend/internal/service/ops_upstream_context_test.go @@ -8,6 +8,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestSafeUpstreamURL(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"}, + {"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"}, + {"strips both", "https://host/path?token=secret#x", "https://host/path"}, + {"no query 
or fragment", "https://host/path", "https://host/path"}, + {"empty string", "", ""}, + {"whitespace only", " ", ""}, + {"query before fragment", "https://h/p?a=1#f", "https://h/p"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, safeUpstreamURL(tt.input)) + }) + } +} + func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/overload_cooldown_test.go b/backend/internal/service/overload_cooldown_test.go new file mode 100644 index 00000000..ef5e7fd1 --- /dev/null +++ b/backend/internal/service/overload_cooldown_test.go @@ -0,0 +1,298 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// errSettingRepo: a SettingRepository that always returns errors on read +// --------------------------------------------------------------------------- + +type errSettingRepo struct { + mockSettingRepo // embed the existing mock from backup_service_test.go + readErr error +} + +func (r *errSettingRepo) GetValue(_ context.Context, _ string) (string, error) { + return "", r.readErr +} + +func (r *errSettingRepo) Get(_ context.Context, _ string) (*Setting, error) { + return nil, r.readErr +} + +// --------------------------------------------------------------------------- +// overloadAccountRepoStub: records SetOverloaded calls +// --------------------------------------------------------------------------- + +type overloadAccountRepoStub struct { + mockAccountRepoForGemini + overloadCalls int + lastOverloadID int64 + lastOverloadEnd time.Time +} + +func (r *overloadAccountRepoStub) SetOverloaded(_ context.Context, id int64, until time.Time) error { + r.overloadCalls++ + r.lastOverloadID = id + 
r.lastOverloadEnd = until + return nil +} + +// =========================================================================== +// SettingService: GetOverloadCooldownSettings +// =========================================================================== + +func TestGetOverloadCooldownSettings_DefaultsWhenNotSet(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ReadsFromDB(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 30}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 30, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ClampsMinValue(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 0}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ClampsMaxValue(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 999}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 120, 
settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_InvalidJSON_ReturnsDefaults(t *testing.T) { + repo := newMockSettingRepo() + repo.data[SettingKeyOverloadCooldownSettings] = "not-json" + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_EmptyValue_ReturnsDefaults(t *testing.T) { + repo := newMockSettingRepo() + repo.data[SettingKeyOverloadCooldownSettings] = "" + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +// =========================================================================== +// SettingService: SetOverloadCooldownSettings +// =========================================================================== + +func TestSetOverloadCooldownSettings_Success(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: false, + CooldownMinutes: 25, + }) + require.NoError(t, err) + + // Verify round-trip + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 25, settings.CooldownMinutes) +} + +func TestSetOverloadCooldownSettings_RejectsNil(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + err := svc.SetOverloadCooldownSettings(context.Background(), nil) + require.Error(t, err) +} + +func TestSetOverloadCooldownSettings_EnabledRejectsOutOfRange(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + + for _, minutes := range []int{0, 
-1, 121, 999} { + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: true, CooldownMinutes: minutes, + }) + require.Error(t, err, "should reject enabled=true + cooldown_minutes=%d", minutes) + require.Contains(t, err.Error(), "cooldown_minutes must be between 1-120") + } +} + +func TestSetOverloadCooldownSettings_DisabledNormalizesOutOfRange(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + // enabled=false + cooldown_minutes=0 应该保存成功,值被归一化为10 + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: false, CooldownMinutes: 0, + }) + require.NoError(t, err, "disabled with invalid minutes should NOT be rejected") + + // 验证持久化后读回来的值 + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes, "should be normalized to default") +} + +func TestSetOverloadCooldownSettings_AcceptsBoundaries(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + + for _, minutes := range []int{1, 60, 120} { + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: true, CooldownMinutes: minutes, + }) + require.NoError(t, err, "should accept cooldown_minutes=%d", minutes) + } +} + +// =========================================================================== +// RateLimitService: handle529 behaviour +// =========================================================================== + +func TestHandle529_EnabledFromDB_PausesAccount(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 15}) + settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := 
NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.Equal(t, int64(42), accountRepo.lastOverloadID) + require.WithinDuration(t, before.Add(15*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_DisabledFromDB_SkipsAccount(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 15}) + settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + svc.handle529(context.Background(), account) + + require.Equal(t, 0, accountRepo.overloadCalls, "should NOT pause when disabled") +} + +func TestHandle529_NilSettingService_FallsBackToConfig(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + cfg := &config.Config{} + cfg.RateLimit.OverloadCooldownMinutes = 20 + svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil) + // NOT calling SetSettingService — remains nil + + account := &Account{ID: 77, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(20*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_NilSettingService_ZeroConfig_DefaultsTen(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + + account := 
&Account{ID: 88, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(10*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_DBReadError_FallsBackToConfig(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + errRepo := &errSettingRepo{readErr: context.DeadlineExceeded} + errRepo.data = make(map[string]string) + + cfg := &config.Config{} + cfg.RateLimit.OverloadCooldownMinutes = 7 + settingSvc := NewSettingService(errRepo, cfg) + svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 99, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(7*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +// =========================================================================== +// Model: defaults & JSON round-trip +// =========================================================================== + +func TestDefaultOverloadCooldownSettings(t *testing.T) { + d := DefaultOverloadCooldownSettings() + require.True(t, d.Enabled) + require.Equal(t, 10, d.CooldownMinutes) +} + +func TestOverloadCooldownSettings_JSONRoundTrip(t *testing.T) { + original := OverloadCooldownSettings{Enabled: false, CooldownMinutes: 42} + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded OverloadCooldownSettings + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, original, decoded) + + // Verify JSON uses snake_case field names + var raw map[string]any + require.NoError(t, json.Unmarshal(data, &raw)) + _, hasEnabled := raw["enabled"] + _, hasCooldown := raw["cooldown_minutes"] + require.True(t, hasEnabled, "JSON must use 
'enabled'") + require.True(t, hasCooldown, "JSON must use 'cooldown_minutes'") +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 7ed4e7e4..5623d4b7 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -34,6 +34,22 @@ var ( Mode: "chat", SupportsPromptCaching: true, } + openAIGPT54MiniFallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 7.5e-07, + OutputCostPerToken: 4.5e-06, + CacheReadInputTokenCost: 7.5e-08, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } + openAIGPT54NanoFallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 2e-07, + OutputCostPerToken: 1.25e-06, + CacheReadInputTokenCost: 2e-08, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } ) // LiteLLMModelPricing LiteLLM价格数据结构 @@ -173,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error { return s.downloadPricingData() } - // 检查文件是否过期 + // 先加载本地文件(确保服务可用),再检查是否需要更新 + if err := s.loadPricingData(pricingFile); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err) + return s.downloadPricingData() + } + + // 如果配置了哈希URL,通过远程哈希检查是否有更新 + if s.cfg.Pricing.HashURL != "" { + remoteHash, err := s.fetchRemoteHash() + if err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err) + return nil // 已加载本地文件,哈希获取失败不影响启动 + } + + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) + if err := s.downloadPricingData(); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) + } + } + return nil + } + 
+ // 没有哈希URL时,基于文件年龄检查 info, err := os.Stat(pricingFile) if err != nil { - return s.downloadPricingData() + return nil // 已加载本地文件 } fileAge := time.Since(info.ModTime()) @@ -189,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error { } } - // 加载本地文件 - return s.loadPricingData(pricingFile) + return nil } // syncWithRemote 与远程同步(基于哈希校验) func (s *PricingService) syncWithRemote() error { - pricingFile := s.getPricingFilePath() - - // 计算本地文件哈希 - localHash, err := s.computeFileHash(pricingFile) - if err != nil { - logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) - return s.downloadPricingData() - } - // 如果配置了哈希URL,从远程获取哈希进行比对 if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() @@ -212,8 +246,13 @@ func (s *PricingService) syncWithRemote() error { return nil // 哈希获取失败不影响正常使用 } - if remoteHash != localHash { - logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) return s.downloadPricingData() } logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") @@ -221,6 +260,7 @@ func (s *PricingService) syncWithRemote() error { } // 没有哈希URL时,基于时间检查 + pricingFile := s.getPricingFilePath() info, err := os.Stat(pricingFile) if err != nil { return s.downloadPricingData() @@ -248,11 +288,12 @@ func (s *PricingService) downloadPricingData() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var expectedHash string + // 获取远程哈希(用于同步锚点,不作为完整性校验) + var remoteHash string if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { - expectedHash, err = s.fetchRemoteHash() + remoteHash, 
err = s.fetchRemoteHash() if err != nil { - return fmt.Errorf("fetch remote hash: %w", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err) } } @@ -261,11 +302,13 @@ func (s *PricingService) downloadPricingData() error { return fmt.Errorf("download failed: %w", err) } - if expectedHash != "" { - actualHash := sha256.Sum256(body) - if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { - return fmt.Errorf("pricing hash mismatch") - } + // 哈希校验:不匹配时仅告警,不阻止更新 + // 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件) + dataHash := sha256.Sum256(body) + dataHashStr := hex.EncodeToString(dataHash[:]) + if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) { + logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)", + remoteHash[:min(8, len(remoteHash))], dataHashStr[:8]) } // 解析JSON数据(使用灵活的解析方式) @@ -280,11 +323,14 @@ func (s *PricingService) downloadPricingData() error { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } - // 保存哈希 - hash := sha256.Sum256(body) - hashStr := hex.EncodeToString(hash[:]) + // 使用远程哈希作为同步锚点,防止重复下载 + // 当远程哈希不可用时,回退到数据本身的哈希 + syncHash := dataHashStr + if remoteHash != "" { + syncHash = remoteHash + } hashFile := s.getHashFilePath() - if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { + if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) } @@ -292,7 +338,7 @@ func (s *PricingService) downloadPricingData() error { s.mu.Lock() s.pricingData = data s.lastUpdated = time.Now() - s.localHash = hashStr + s.localHash = syncHash s.mu.Unlock() logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) @@ -470,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) { return normalized, nil } -// 
computeFileHash 计算文件哈希 -func (s *PricingService) computeFileHash(filePath string) (string, error) { - data, err := os.ReadFile(filePath) - if err != nil { - return "", err - } - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]), nil -} - // GetModelPricing 获取模型价格(带模糊匹配) func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing { s.mu.RLock() @@ -723,6 +759,18 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.4-mini") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)")) + return openAIGPT54MiniFallbackPricing + } + + if strings.HasPrefix(model, "gpt-5.4-nano") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-nano(static)")) + return openAIGPT54NanoFallbackPricing + } + if strings.HasPrefix(model, "gpt-5.4") { logger.With(zap.String("component", "service.pricing")). 
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 775024fd..13a5c70c 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -98,6 +98,36 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) } +func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4-mini") + require.NotNil(t, got) + require.InDelta(t, 7.5e-7, got.InputCostPerToken, 1e-12) + require.InDelta(t, 4.5e-6, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 7.5e-8, got.CacheReadInputTokenCost, 1e-12) + require.Zero(t, got.LongContextInputTokenThreshold) +} + +func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4-nano") + require.NotNil(t, got) + require.InDelta(t, 2e-7, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.25e-6, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2e-8, got.CacheReadInputTokenCost, 1e-12) + require.Zero(t, got.LongContextInputTokenThreshold) +} + func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { raw := map[string]any{ "gpt-5.4": map[string]any{ diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go index fc449091..a2896d6c 100644 --- a/backend/internal/service/proxy.go +++ b/backend/internal/service/proxy.go @@ -1,7 +1,9 @@ package service import ( - "fmt" + "net" + "net/url" + 
"strconv" "time" ) @@ -23,10 +25,14 @@ func (p *Proxy) IsActive() bool { } func (p *Proxy) URL() string { - if p.Username != "" && p.Password != "" { - return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port) + u := &url.URL{ + Scheme: p.Protocol, + Host: net.JoinHostPort(p.Host, strconv.Itoa(p.Port)), } - return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port) + if p.Username != "" && p.Password != "" { + u.User = url.UserPassword(p.Username, p.Password) + } + return u.String() } type ProxyWithAccountCount struct { diff --git a/backend/internal/service/proxy_test.go b/backend/internal/service/proxy_test.go new file mode 100644 index 00000000..da6d1236 --- /dev/null +++ b/backend/internal/service/proxy_test.go @@ -0,0 +1,95 @@ +package service + +import ( + "net/url" + "testing" +) + +func TestProxyURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + proxy Proxy + want string + }{ + { + name: "without auth", + proxy: Proxy{ + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, + want: "http://proxy.example.com:8080", + }, + { + name: "with auth", + proxy: Proxy{ + Protocol: "socks5", + Host: "socks.example.com", + Port: 1080, + Username: "user", + Password: "pass", + }, + want: "socks5://user:pass@socks.example.com:1080", + }, + { + name: "username only keeps no auth for compatibility", + proxy: Proxy{ + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + Username: "user-only", + }, + want: "http://proxy.example.com:8080", + }, + { + name: "with special characters in credentials", + proxy: Proxy{ + Protocol: "http", + Host: "proxy.example.com", + Port: 3128, + Username: "first last@corp", + Password: "p@ ss:#word", + }, + want: "http://first%20last%40corp:p%40%20ss%3A%23word@proxy.example.com:3128", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := tc.proxy.URL(); got != tc.want { + t.Fatalf("Proxy.URL() mismatch: 
got=%q want=%q", got, tc.want) + } + }) + } +} + +func TestProxyURL_SpecialCharactersRoundTrip(t *testing.T) { + t.Parallel() + + proxy := Proxy{ + Protocol: "http", + Host: "proxy.example.com", + Port: 3128, + Username: "first last@corp", + Password: "p@ ss:#word", + } + + parsed, err := url.Parse(proxy.URL()) + if err != nil { + t.Fatalf("parse proxy URL failed: %v", err) + } + if got := parsed.User.Username(); got != proxy.Username { + t.Fatalf("username mismatch after parse: got=%q want=%q", got, proxy.Username) + } + pass, ok := parsed.User.Password() + if !ok { + t.Fatal("password missing after parse") + } + if pass != proxy.Password { + t.Fatalf("password mismatch after parse: got=%q want=%q", pass, proxy.Password) + } +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 5861a811..aa0ae200 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" ) // RateLimitService 处理限流和过载状态管理 @@ -149,6 +150,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: + // OpenAI: token_invalidated / token_revoked 表示 token 被永久作废(非过期),直接标记 error + openai401Code := extractUpstreamErrorCode(responseBody) + if account.Platform == PlatformOpenAI && (openai401Code == "token_invalidated" || openai401Code == "token_revoked") { + msg := "Token revoked (401): account authentication permanently revoked" + if upstreamMsg != "" { + msg = "Token revoked (401): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true + break + } // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 if account.Type == AccountTypeOAuth && account.Platform != 
PlatformAntigravity { @@ -163,7 +175,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc account.Credentials = make(map[string]any) } account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) - if err := s.accountRepo.Update(ctx, account); err != nil { + if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil { slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) @@ -192,6 +204,13 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc shouldDisable = true } case 402: + // OpenAI: deactivated_workspace 表示工作区已停用,直接标记 error + if account.Platform == PlatformOpenAI && gjson.GetBytes(responseBody, "detail.code").String() == "deactivated_workspace" { + msg := "Workspace deactivated (402): workspace has been deactivated" + s.handleAuthError(ctx, account, msg) + shouldDisable = true + break + } // 支付要求:余额不足或计费问题,停止调度 msg := "Payment required (402): insufficient balance or billing issue" if upstreamMsg != "" { @@ -1023,11 +1042,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 { } // handle529 处理529过载错误 -// 根据配置设置过载冷却时间 +// 根据配置决定是否暂停账号调度及冷却时长 func (s *RateLimitService) handle529(ctx context.Context, account *Account) { - cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes + var settings *OverloadCooldownSettings + if s.settingService != nil { + var err error + settings, err = s.settingService.GetOverloadCooldownSettings(ctx) + if err != nil { + slog.Warn("overload_settings_read_failed", "account_id", account.ID, "error", err) + settings = nil + } + } + // 回退到配置文件 + if settings == nil { + cooldown := s.cfg.RateLimit.OverloadCooldownMinutes + if cooldown <= 0 { + cooldown = 10 + } + settings = &OverloadCooldownSettings{Enabled: true, CooldownMinutes: cooldown} + } + + if !settings.Enabled { + 
slog.Info("account_529_ignored", "account_id", account.ID, "reason", "overload_cooldown_disabled") + return + } + + cooldownMinutes := settings.CooldownMinutes if cooldownMinutes <= 0 { - cooldownMinutes = 10 // 默认10分钟 + cooldownMinutes = 10 } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) @@ -1051,18 +1093,49 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc var windowStart, windowEnd *time.Time needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd) - if needInitWindow && (status == "allowed" || status == "allowed_warning") { - // 预测时间窗口:从当前时间的整点开始,+5小时为结束 - // 例如:现在是 14:30,窗口为 14:00 ~ 19:00 + // 优先使用响应头中的真实重置时间(比预测更准确) + if resetStr := headers.Get("anthropic-ratelimit-unified-5h-reset"); resetStr != "" { + if ts, err := strconv.ParseInt(resetStr, 10, 64); err == nil { + // 检测可能的毫秒时间戳(秒级约为 1e9,毫秒约为 1e12) + if ts > 1e11 { + slog.Warn("account_session_window_header_millis_detected", "account_id", account.ID, "raw_reset", resetStr) + ts = ts / 1000 + } + end := time.Unix(ts, 0) + // 校验时间戳是否在合理范围内(不早于 5h 前,不晚于 7 天后) + minAllowed := time.Now().Add(-5 * time.Hour) + maxAllowed := time.Now().Add(7 * 24 * time.Hour) + if end.Before(minAllowed) || end.After(maxAllowed) { + slog.Warn("account_session_window_header_out_of_range", "account_id", account.ID, "raw_reset", resetStr, "parsed_end", end) + } else if needInitWindow || account.SessionWindowEnd == nil || !end.Equal(*account.SessionWindowEnd) { + // 窗口需要初始化,或者真实重置时间与已存储的不同,则更新 + start := end.Add(-5 * time.Hour) + windowStart = &start + windowEnd = &end + slog.Info("account_session_window_from_header", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) + } + } else { + slog.Warn("account_session_window_header_parse_failed", "account_id", account.ID, "raw_reset", resetStr, "error", err) + } + } + + // 回退:如果没有真实重置时间且需要初始化窗口,使用预测 + if windowEnd == nil && needInitWindow && (status == "allowed" 
|| status == "allowed_warning") { now := time.Now() start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) end := start.Add(5 * time.Hour) windowStart = &start windowEnd = &end slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status) - // 窗口重置时清除旧的 utilization,避免残留上个窗口的数据 + } + + // 窗口重置时清除旧的 utilization 和被动采样数据,避免残留上个窗口的数据 + if windowEnd != nil && needInitWindow { _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ - "session_window_utilization": nil, + "session_window_utilization": nil, + "passive_usage_7d_utilization": nil, + "passive_usage_7d_reset": nil, + "passive_usage_sampled_at": nil, }) } @@ -1070,14 +1143,33 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err) } - // 存储真实的 utilization 值(0-1 小数),供 estimateSetupTokenUsage 使用 + // 被动采样:从响应头收集 5h + 7d utilization,合并为一次 DB 写入 + extraUpdates := make(map[string]any, 4) + // 5h utilization(0-1 小数),供 estimateSetupTokenUsage 使用 if utilStr := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" { if util, err := strconv.ParseFloat(utilStr, 64); err == nil { - if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{ - "session_window_utilization": util, - }); err != nil { - slog.Warn("session_window_utilization_update_failed", "account_id", account.ID, "error", err) + extraUpdates["session_window_utilization"] = util + } + } + // 7d utilization(0-1 小数) + if utilStr := headers.Get("anthropic-ratelimit-unified-7d-utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil { + extraUpdates["passive_usage_7d_utilization"] = util + } + } + // 7d reset timestamp + if resetStr := headers.Get("anthropic-ratelimit-unified-7d-reset"); resetStr != "" { + if ts, err := strconv.ParseInt(resetStr, 10, 64); err == nil { + if ts > 
1e11 { + ts = ts / 1000 } + extraUpdates["passive_usage_7d_reset"] = ts + } + } + if len(extraUpdates) > 0 { + extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339) + if err := s.accountRepo.UpdateExtra(ctx, account.ID, extraUpdates); err != nil { + slog.Warn("passive_usage_update_failed", "account_id", account.ID, "error", err) } } diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 4a6e5d6c..67b22e52 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -15,9 +15,11 @@ import ( type rateLimitAccountRepoStub struct { mockAccountRepoForGemini - setErrorCalls int - tempCalls int - lastErrorMsg string + setErrorCalls int + tempCalls int + updateCredentialsCalls int + lastCredentials map[string]any + lastErrorMsg string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id return nil } +func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCredentialsCalls++ + r.lastCredentials = cloneCredentials(credentials) + return nil +} + type tokenCacheInvalidatorRecorder struct { accounts []*Account err error @@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Len(t, invalidator.accounts, 1) } @@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) } + +func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) { + repo 
:= &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 103, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotEmpty(t, repo.lastCredentials["expires_at"]) +} diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go new file mode 100644 index 00000000..7796a85e --- /dev/null +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -0,0 +1,370 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// sessionWindowMockRepo is a minimal AccountRepository mock that records calls +// made by UpdateSessionWindow. Unrelated methods panic if invoked. 
+type sessionWindowMockRepo struct { + // captured calls + sessionWindowCalls []swCall + updateExtraCalls []ueCall + clearRateLimitIDs []int64 +} + +var _ AccountRepository = (*sessionWindowMockRepo)(nil) + +type swCall struct { + ID int64 + Start *time.Time + End *time.Time + Status string +} + +type ueCall struct { + ID int64 + Updates map[string]any +} + +func (m *sessionWindowMockRepo) UpdateSessionWindow(_ context.Context, id int64, start, end *time.Time, status string) error { + m.sessionWindowCalls = append(m.sessionWindowCalls, swCall{ID: id, Start: start, End: end, Status: status}) + return nil +} +func (m *sessionWindowMockRepo) UpdateExtra(_ context.Context, id int64, updates map[string]any) error { + m.updateExtraCalls = append(m.updateExtraCalls, ueCall{ID: id, Updates: updates}) + return nil +} +func (m *sessionWindowMockRepo) ClearRateLimit(_ context.Context, id int64) error { + m.clearRateLimitIDs = append(m.clearRateLimitIDs, id) + return nil +} +func (m *sessionWindowMockRepo) ClearAntigravityQuotaScopes(_ context.Context, _ int64) error { + return nil +} +func (m *sessionWindowMockRepo) ClearModelRateLimits(_ context.Context, _ int64) error { + return nil +} +func (m *sessionWindowMockRepo) ClearTempUnschedulable(_ context.Context, _ int64) error { + return nil +} + +// --- Unused interface methods (panic on unexpected call) --- + +func (m *sessionWindowMockRepo) Create(context.Context, *Account) error { panic("unexpected") } +func (m *sessionWindowMockRepo) GetByID(context.Context, int64) (*Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) GetByIDs(context.Context, []int64) ([]*Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ExistsByID(context.Context, int64) (bool, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) GetByCRSAccountID(context.Context, string) (*Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) FindByExtraField(context.Context, string, 
any) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) Update(context.Context, *Account) error { panic("unexpected") } +func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic("unexpected") } +func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListActive(context.Context) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) UpdateLastUsed(context.Context, int64) error { panic("unexpected") } +func (m *sessionWindowMockRepo) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) SetError(context.Context, int64, string) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ClearError(context.Context, int64) error { panic("unexpected") } +func (m *sessionWindowMockRepo) SetSchedulable(context.Context, int64, bool) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) BindGroups(context.Context, int64, []int64) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulable(context.Context) ([]Account, error) { + panic("unexpected") +} +func (m 
*sessionWindowMockRepo) ListSchedulableByGroupID(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) BulkUpdate(context.Context, []int64, AccountBulkUpdate) (int64, error) { + panic("unexpected") +} +func (m *sessionWindowMockRepo) IncrementQuotaUsed(context.Context, int64, float64) error { + panic("unexpected") +} +func (m *sessionWindowMockRepo) ResetQuotaUsed(context.Context, int64) error { panic("unexpected") } + +// newRateLimitServiceForTest creates a RateLimitService with the given mock repo. 
+func newRateLimitServiceForTest(repo AccountRepository) *RateLimitService { + return &RateLimitService{accountRepo: repo} +} + +func TestUpdateSessionWindow_UsesResetHeader(t *testing.T) { + // The reset header provides the real window end as a Unix timestamp. + // UpdateSessionWindow should use it instead of the hour-truncated prediction. + resetUnix := time.Now().Add(3 * time.Hour).Unix() + wantEnd := time.Unix(resetUnix, 0) + wantStart := wantEnd.Add(-5 * time.Hour) + + repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ID: 42} // no existing window → needInitWindow=true + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed") + headers.Set("anthropic-ratelimit-unified-5h-reset", fmt.Sprintf("%d", resetUnix)) + + svc.UpdateSessionWindow(context.Background(), account, headers) + + if len(repo.sessionWindowCalls) != 1 { + t.Fatalf("expected 1 UpdateSessionWindow call, got %d", len(repo.sessionWindowCalls)) + } + + call := repo.sessionWindowCalls[0] + if call.ID != 42 { + t.Errorf("expected account ID 42, got %d", call.ID) + } + if call.End == nil || !call.End.Equal(wantEnd) { + t.Errorf("expected window end %v, got %v", wantEnd, call.End) + } + if call.Start == nil || !call.Start.Equal(wantStart) { + t.Errorf("expected window start %v, got %v", wantStart, call.Start) + } + if call.Status != "allowed" { + t.Errorf("expected status 'allowed', got %q", call.Status) + } +} + +func TestUpdateSessionWindow_FallbackPredictionWhenNoResetHeader(t *testing.T) { + // When the reset header is absent, should fall back to hour-truncated prediction. 
+ repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ID: 10} // no existing window + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed_warning") + // No anthropic-ratelimit-unified-5h-reset header + + // Capture now before the call to avoid hour-boundary races + now := time.Now() + expectedStart := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) + expectedEnd := expectedStart.Add(5 * time.Hour) + + svc.UpdateSessionWindow(context.Background(), account, headers) + + if len(repo.sessionWindowCalls) != 1 { + t.Fatalf("expected 1 UpdateSessionWindow call, got %d", len(repo.sessionWindowCalls)) + } + + call := repo.sessionWindowCalls[0] + if call.End == nil { + t.Fatal("expected window end to be set (fallback prediction)") + } + // Fallback: start = current hour truncated, end = start + 5h + + if !call.End.Equal(expectedEnd) { + t.Errorf("expected fallback end %v, got %v", expectedEnd, *call.End) + } + if call.Start == nil || !call.Start.Equal(expectedStart) { + t.Errorf("expected fallback start %v, got %v", expectedStart, call.Start) + } +} + +func TestUpdateSessionWindow_CorrectsStalePrediction(t *testing.T) { + // When the stored SessionWindowEnd is wrong (from a previous prediction), + // and the reset header provides the real time, it should update the window. 
+ staleEnd := time.Now().Add(2 * time.Hour) // existing prediction: 2h from now + realResetUnix := time.Now().Add(4 * time.Hour).Unix() // real reset: 4h from now + wantEnd := time.Unix(realResetUnix, 0) + + repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ + ID: 55, + SessionWindowEnd: &staleEnd, + } + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed") + headers.Set("anthropic-ratelimit-unified-5h-reset", fmt.Sprintf("%d", realResetUnix)) + + svc.UpdateSessionWindow(context.Background(), account, headers) + + if len(repo.sessionWindowCalls) != 1 { + t.Fatalf("expected 1 UpdateSessionWindow call, got %d", len(repo.sessionWindowCalls)) + } + + call := repo.sessionWindowCalls[0] + if call.End == nil || !call.End.Equal(wantEnd) { + t.Errorf("expected corrected end %v, got %v", wantEnd, call.End) + } +} + +func TestUpdateSessionWindow_NoUpdateWhenHeaderMatchesStored(t *testing.T) { + // If the reset header matches the stored SessionWindowEnd, no window update needed. 
+ futureUnix := time.Now().Add(3 * time.Hour).Unix() + existingEnd := time.Unix(futureUnix, 0) + + repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ + ID: 77, + SessionWindowEnd: &existingEnd, + } + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed") + headers.Set("anthropic-ratelimit-unified-5h-reset", fmt.Sprintf("%d", futureUnix)) // same as stored + + svc.UpdateSessionWindow(context.Background(), account, headers) + + if len(repo.sessionWindowCalls) != 1 { + t.Fatalf("expected 1 UpdateSessionWindow call, got %d", len(repo.sessionWindowCalls)) + } + + call := repo.sessionWindowCalls[0] + // windowStart and windowEnd should be nil (no update needed) + if call.Start != nil || call.End != nil { + t.Errorf("expected nil start/end (no window change needed), got start=%v end=%v", call.Start, call.End) + } + // Status is still updated + if call.Status != "allowed" { + t.Errorf("expected status 'allowed', got %q", call.Status) + } +} + +func TestUpdateSessionWindow_ClearsUtilizationOnWindowReset(t *testing.T) { + // When needInitWindow=true and window is set, utilization should be cleared. 
+ resetUnix := time.Now().Add(3 * time.Hour).Unix() + + repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ID: 33} // no existing window → needInitWindow=true + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed") + headers.Set("anthropic-ratelimit-unified-5h-reset", fmt.Sprintf("%d", resetUnix)) + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.15") + + svc.UpdateSessionWindow(context.Background(), account, headers) + + // Should have 2 UpdateExtra calls: one to clear utilization, one to store new utilization + if len(repo.updateExtraCalls) != 2 { + t.Fatalf("expected 2 UpdateExtra calls, got %d", len(repo.updateExtraCalls)) + } + + // First call: clear utilization (nil value) + clearCall := repo.updateExtraCalls[0] + if clearCall.Updates["session_window_utilization"] != nil { + t.Errorf("expected utilization cleared to nil, got %v", clearCall.Updates["session_window_utilization"]) + } + + // Second call: store new utilization + storeCall := repo.updateExtraCalls[1] + if val, ok := storeCall.Updates["session_window_utilization"].(float64); !ok || val != 0.15 { + t.Errorf("expected utilization stored as 0.15, got %v", storeCall.Updates["session_window_utilization"]) + } +} + +func TestUpdateSessionWindow_NoClearUtilizationOnCorrection(t *testing.T) { + // When correcting a stale prediction (needInitWindow=false), utilization should NOT be cleared. 
+ staleEnd := time.Now().Add(2 * time.Hour) + realResetUnix := time.Now().Add(4 * time.Hour).Unix() + + repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ + ID: 66, + SessionWindowEnd: &staleEnd, + } + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-status", "allowed") + headers.Set("anthropic-ratelimit-unified-5h-reset", fmt.Sprintf("%d", realResetUnix)) + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.30") + + svc.UpdateSessionWindow(context.Background(), account, headers) + + // Only 1 UpdateExtra call (store utilization), no clear call + if len(repo.updateExtraCalls) != 1 { + t.Fatalf("expected 1 UpdateExtra call (no clear), got %d", len(repo.updateExtraCalls)) + } + + if val, ok := repo.updateExtraCalls[0].Updates["session_window_utilization"].(float64); !ok || val != 0.30 { + t.Errorf("expected utilization 0.30, got %v", repo.updateExtraCalls[0].Updates["session_window_utilization"]) + } +} + +func TestUpdateSessionWindow_NoStatusHeader(t *testing.T) { + // Should return immediately if no status header. 
+ repo := &sessionWindowMockRepo{} + svc := newRateLimitServiceForTest(repo) + + account := &Account{ID: 1} + + svc.UpdateSessionWindow(context.Background(), account, http.Header{}) + + if len(repo.sessionWindowCalls) != 0 { + t.Errorf("expected no calls when status header absent, got %d", len(repo.sessionWindowCalls)) + } +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4c9540f1..d1330abb 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -152,6 +152,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int return s.accountRepo.GetByID(fallbackCtx, accountID) } +// GetGroupByID 获取分组信息(供调度器使用) +func (s *SchedulerSnapshotService) GetGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if s.groupRepo == nil { + return nil, nil + } + return s.groupRepo.GetByID(ctx, groupID) +} + // UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效) func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error { if s.cache == nil || account == nil { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 6cb13b11..f24707e7 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -44,26 +44,27 @@ type SettingRepository interface { Delete(ctx context.Context, key string) error } -// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL) -type cachedMinVersion struct { - value string // 空字符串 = 不检查 +// cachedVersionBounds 缓存 Claude Code 版本号上下限(进程内缓存,60s TTL) +type cachedVersionBounds struct { + min string // 空字符串 = 不检查 + max string // 空字符串 = 不检查 expiresAt int64 // unix nano } -// minVersionCache 最低版本号进程内缓存 -var minVersionCache atomic.Value // *cachedMinVersion +// versionBoundsCache 版本号上下限进程内缓存 +var versionBoundsCache atomic.Value // 
*cachedVersionBounds -// minVersionSF 防止缓存过期时 thundering herd -var minVersionSF singleflight.Group +// versionBoundsSF 防止缓存过期时 thundering herd +var versionBoundsSF singleflight.Group -// minVersionCacheTTL 缓存有效期 -const minVersionCacheTTL = 60 * time.Second +// versionBoundsCacheTTL 缓存有效期 +const versionBoundsCacheTTL = 60 * time.Second -// minVersionErrorTTL DB 错误时的短缓存,快速重试 -const minVersionErrorTTL = 5 * time.Second +// versionBoundsErrorTTL DB 错误时的短缓存,快速重试 +const versionBoundsErrorTTL = 5 * time.Second -// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context -const minVersionDBTimeout = 5 * time.Second +// versionBoundsDBTimeout singleflight 内 DB 查询超时,独立于请求 context +const versionBoundsDBTimeout = 5 * time.Second // cachedBackendMode Backend Mode cache (in-process, 60s TTL) type cachedBackendMode struct { @@ -78,6 +79,20 @@ const backendModeCacheTTL = 60 * time.Second const backendModeErrorTTL = 5 * time.Second const backendModeDBTimeout = 5 * time.Second +// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL) +type cachedGatewayForwardingSettings struct { + fingerprintUnification bool + metadataPassthrough bool + expiresAt int64 // unix nano +} + +var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings +var gatewayForwardingSF singleflight.Group + +const gatewayForwardingCacheTTL = 60 * time.Second +const gatewayForwardingErrorTTL = 5 * time.Second +const gatewayForwardingDBTimeout = 5 * time.Second + // DefaultSubscriptionGroupReader validates group references used by default subscriptions. 
type DefaultSubscriptionGroupReader interface { GetByID(ctx context.Context, id int64) (*Group, error) @@ -149,6 +164,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, + SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, SettingKeyBackendModeEnabled, } @@ -194,6 +210,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", }, nil @@ -246,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` + CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version,omitempty"` @@ -271,6 +289,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: s.version, @@ -313,6 +332,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +// safeRawJSONArray returns raw as json.RawMessage if it's valid 
JSON, otherwise "[]". +func safeRawJSONArray(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" { + return json.RawMessage("[]") + } + if json.Valid([]byte(raw)) { + return json.RawMessage(raw) + } + return json.RawMessage("[]") +} + // GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { @@ -453,6 +484,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems + updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -484,6 +516,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion + updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion // 分组隔离 updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) @@ -491,19 +524,30 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Backend Mode updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled) + // Gateway forwarding behavior + updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) + updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) + err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - 
minVersionSF.Forget("min_version") - minVersionCache.Store(&cachedMinVersion{ - value: settings.MinClaudeCodeVersion, - expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), + versionBoundsSF.Forget("version_bounds") + versionBoundsCache.Store(&cachedVersionBounds{ + min: settings.MinClaudeCodeVersion, + max: settings.MaxClaudeCodeVersion, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), }) backendModeSF.Forget("backend_mode") backendModeCache.Store(&cachedBackendMode{ value: settings.BackendModeEnabled, expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), }) + gatewayForwardingSF.Forget("gateway_forwarding") + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -606,6 +650,57 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { return false } +// GetGatewayForwardingSettings returns cached gateway forwarding settings. +// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. +// Returns (fingerprintUnification, metadataPassthrough). 
+func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough bool) { + if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.fingerprintUnification, cached.metadataPassthrough + } + } + type gwfResult struct { + fp, mp bool + } + val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { + if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough}, nil + } + } + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) + defer cancel() + values, err := s.settingRepo.GetMultiple(dbCtx, []string{ + SettingKeyEnableFingerprintUnification, + SettingKeyEnableMetadataPassthrough, + }) + if err != nil { + slog.Warn("failed to get gateway forwarding settings", "error", err) + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: true, + metadataPassthrough: false, + expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), + }) + return gwfResult{true, false}, nil + } + fp := true + if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { + fp = v == "true" + } + mp := values[SettingKeyEnableMetadataPassthrough] == "true" + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: fp, + metadataPassthrough: mp, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + return gwfResult{fp, mp}, nil + }) + if r, ok := val.(gwfResult); ok { + return r.fp, r.mp + } + return true, false // fail-open defaults +} + // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, 
SettingKeyEmailVerifyEnabled) @@ -737,6 +832,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionURL: "", SettingKeySoraClientEnabled: "false", SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -760,6 +856,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", + SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) SettingKeyAllowUngroupedKeyScheduling: "false", @@ -801,6 +898,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } @@ -895,10 +993,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] + result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] // 分组隔离 result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true" + // Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false) + if v, ok := settings[SettingKeyEnableFingerprintUnification]; ok && v != "" { + result.EnableFingerprintUnification = v == "true" + } else { + result.EnableFingerprintUnification = true // default: enabled (current behavior) + } + result.EnableMetadataPassthrough = 
settings[SettingKeyEnableMetadataPassthrough] == "true" + return result } @@ -1172,6 +1279,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetOverloadCooldownSettings 获取529过载冷却配置 +func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultOverloadCooldownSettings(), nil + } + return nil, fmt.Errorf("get overload cooldown settings: %w", err) + } + if value == "" { + return DefaultOverloadCooldownSettings(), nil + } + + var settings OverloadCooldownSettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultOverloadCooldownSettings(), nil + } + + // 修正配置值范围 + if settings.CooldownMinutes < 1 { + settings.CooldownMinutes = 1 + } + if settings.CooldownMinutes > 120 { + settings.CooldownMinutes = 120 + } + + return &settings, nil +} + +// SetOverloadCooldownSettings 设置529过载冷却配置 +func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settings *OverloadCooldownSettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + // 禁用时修正为合法值即可,不拒绝请求 + if settings.CooldownMinutes < 1 || settings.CooldownMinutes > 120 { + if settings.Enabled { + return fmt.Errorf("cooldown_minutes must be between 1-120") + } + settings.CooldownMinutes = 10 // 禁用状态下归一化为默认值 + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal overload cooldown settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data)) +} + // GetStreamTimeoutSettings 获取流超时处理配置 func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) @@ -1230,51 +1388,61 @@ func (s 
*SettingService) IsUngroupedKeySchedulingAllowed(ctx context.Context) bo return value == "true" } -// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 +// GetClaudeCodeVersionBounds 获取 Claude Code 版本号上下限要求 // 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销 // singleflight 防止缓存过期时 thundering herd -// 返回空字符串表示不做版本检查 -func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { - if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { +// 返回空字符串表示不做对应方向的版本检查 +func (s *SettingService) GetClaudeCodeVersionBounds(ctx context.Context) (min, max string) { + if cached, ok := versionBoundsCache.Load().(*cachedVersionBounds); ok { if time.Now().UnixNano() < cached.expiresAt { - return cached.value + return cached.min, cached.max } } // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 - result, err, _ := minVersionSF.Do("min_version", func() (any, error) { + type bounds struct{ min, max string } + result, err, _ := versionBoundsSF.Do("version_bounds", func() (any, error) { // 二次检查,避免排队的 goroutine 重复查询 - if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { + if cached, ok := versionBoundsCache.Load().(*cachedVersionBounds); ok { if time.Now().UnixNano() < cached.expiresAt { - return cached.value, nil + return bounds{cached.min, cached.max}, nil } } // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存 - dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout) + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), versionBoundsDBTimeout) defer cancel() - value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion) + values, err := s.settingRepo.GetMultiple(dbCtx, []string{ + SettingKeyMinClaudeCodeVersion, + SettingKeyMaxClaudeCodeVersion, + }) if err != nil { // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试 - slog.Warn("failed to get min claude code version setting, skipping version check", "error", err) - minVersionCache.Store(&cachedMinVersion{ - value: "", - expiresAt: 
time.Now().Add(minVersionErrorTTL).UnixNano(), + slog.Warn("failed to get claude code version bounds setting, skipping version check", "error", err) + versionBoundsCache.Store(&cachedVersionBounds{ + min: "", + max: "", + expiresAt: time.Now().Add(versionBoundsErrorTTL).UnixNano(), }) - return "", nil + return bounds{"", ""}, nil } - minVersionCache.Store(&cachedMinVersion{ - value: value, - expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), + b := bounds{ + min: values[SettingKeyMinClaudeCodeVersion], + max: values[SettingKeyMaxClaudeCodeVersion], + } + versionBoundsCache.Store(&cachedVersionBounds{ + min: b.min, + max: b.max, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), }) - return value, nil + return b, nil }) if err != nil { - return "" + return "", "" } - ver, ok := result.(string) + b, ok := result.(bounds) if !ok { - return "" + return "", "" } - return ver + return b.min, b.max } // GetRectifierSettings 获取请求整流器配置 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 71c2e7aa..411939bb 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -43,6 +43,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints DefaultConcurrency int DefaultBalance float64 @@ -67,12 +68,17 @@ type SystemSettings struct { // Claude Code version check MinClaudeCodeVersion string + MaxClaudeCodeVersion string // 分组隔离:允许未分组 Key 调度(默认 false → 403) AllowUngroupedKeyScheduling bool // Backend 模式:禁用用户注册和自助服务,仅管理员可登录 BackendModeEnabled bool + + // Gateway forwarding behavior + EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) + EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) } type DefaultSubscriptionSetting struct { @@ -103,6 +109,7 @@ type PublicSettings struct { PurchaseSubscriptionURL 
string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool BackendModeEnabled bool @@ -183,9 +190,11 @@ func DefaultStreamTimeoutSettings() *StreamTimeoutSettings { // RectifierSettings 请求整流器配置 type RectifierSettings struct { - Enabled bool `json:"enabled"` // 总开关 - ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流 - ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流 + Enabled bool `json:"enabled"` // 总开关 + ThinkingSignatureEnabled bool `json:"thinking_signature_enabled"` // Thinking 签名整流 + ThinkingBudgetEnabled bool `json:"thinking_budget_enabled"` // Thinking Budget 整流 + APIKeySignatureEnabled bool `json:"apikey_signature_enabled"` // API Key 签名整流开关 + APIKeySignaturePatterns []string `json:"apikey_signature_patterns"` // API Key 自定义匹配关键词 } // DefaultRectifierSettings 返回默认的整流器配置(全部启用) @@ -222,6 +231,22 @@ type BetaPolicySettings struct { Rules []BetaPolicyRule `json:"rules"` } +// OverloadCooldownSettings 529过载冷却配置 +type OverloadCooldownSettings struct { + // Enabled 是否在收到529时暂停账号调度 + Enabled bool `json:"enabled"` + // CooldownMinutes 冷却时长(分钟) + CooldownMinutes int `json:"cooldown_minutes"` +} + +// DefaultOverloadCooldownSettings 返回默认的过载冷却配置(启用,10分钟) +func DefaultOverloadCooldownSettings() *OverloadCooldownSettings { + return &OverloadCooldownSettings{ + Enabled: true, + CooldownMinutes: 10, + } +} + // DefaultBetaPolicySettings 返回默认的 Beta 策略配置 func DefaultBetaPolicySettings() *BetaPolicySettings { return &BetaPolicySettings{ diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ab6871bb..e9d325f4 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun 
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) return nil, errors.New("model is required") } + originalModel := reqModel mappedModel := account.GetMappedModel(reqModel) + var upstreamModel string if mappedModel != "" && mappedModel != reqModel { reqModel = mappedModel + upstreamModel = mappedModel } modelCfg, ok := GetSoraModelConfig(reqModel) @@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } @@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, resp) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { @@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } return &ForwardResult{ - RequestID: taskID, - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: 
imageSize, + RequestID: taskID, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 206636ff..2fef600c 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { ID: 1, Platform: PlatformSora, Status: StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "prompt-enhance-short-10s": "prompt-enhance-short-15s", + }, + }, } body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) @@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { require.NotNil(t, result) require.Equal(t, "prompt", result.MediaType) require.Equal(t, "prompt-enhance-short-10s", result.Model) + require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) } func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go index 46f322c8..6f33ff39 100644 --- a/backend/internal/service/sora_generation_service_test.go +++ b/backend/internal/service/sora_generation_service_test.go @@ -162,6 +162,9 @@ func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, err func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } +func (r *stubUserRepoForQuota) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} func (r 
*stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go index 040e427d..da8efe77 100644 --- a/backend/internal/service/sora_quota_service_test.go +++ b/backend/internal/service/sora_quota_service_test.go @@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([ func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { - return 0, nil +func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) { + return 0, 0, nil } func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index f9221c5b..6243f867 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun } if c.accountRepo != nil { - if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { + if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() { c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) } } diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go index 0defafba..40bab206 100644 --- 
a/backend/internal/service/subscription_assign_idempotency_test.go +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { panic("unexpected ExistsByName call") } -func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { +func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { @@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected ListByGroupID call") } -func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { +func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected List call") } func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index af548509..f0a5540e 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI } // List 获取所有订阅(分页,支持筛选和排序) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, 
error) { +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) + subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { return nil, nil, err } diff --git a/backend/internal/service/tls_fingerprint_profile_service.go b/backend/internal/service/tls_fingerprint_profile_service.go new file mode 100644 index 00000000..33937cc7 --- /dev/null +++ b/backend/internal/service/tls_fingerprint_profile_service.go @@ -0,0 +1,259 @@ +package service + +import ( + "context" + "math/rand/v2" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" +) + +// TLSFingerprintProfileRepository 定义 TLS 指纹模板的数据访问接口 +type TLSFingerprintProfileRepository interface { + List(ctx context.Context) ([]*model.TLSFingerprintProfile, error) + GetByID(ctx context.Context, id int64) (*model.TLSFingerprintProfile, error) + Create(ctx context.Context, profile *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) + Update(ctx context.Context, profile *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) + Delete(ctx context.Context, id int64) error +} + +// TLSFingerprintProfileCache 定义 TLS 指纹模板的缓存接口 +type TLSFingerprintProfileCache interface { + Get(ctx context.Context) ([]*model.TLSFingerprintProfile, bool) + Set(ctx context.Context, profiles []*model.TLSFingerprintProfile) error + Invalidate(ctx context.Context) error + NotifyUpdate(ctx context.Context) error + SubscribeUpdates(ctx context.Context, handler func()) +} + +// TLSFingerprintProfileService TLS 指纹模板管理服务 +type 
TLSFingerprintProfileService struct { + repo TLSFingerprintProfileRepository + cache TLSFingerprintProfileCache + + // 本地 ID→Profile 映射缓存,用于 DoWithTLS 热路径快速查找 + localCache map[int64]*model.TLSFingerprintProfile + localMu sync.RWMutex +} + +// NewTLSFingerprintProfileService 创建 TLS 指纹模板服务 +func NewTLSFingerprintProfileService( + repo TLSFingerprintProfileRepository, + cache TLSFingerprintProfileCache, +) *TLSFingerprintProfileService { + svc := &TLSFingerprintProfileService{ + repo: repo, + cache: cache, + localCache: make(map[int64]*model.TLSFingerprintProfile), + } + + ctx := context.Background() + if err := svc.reloadFromDB(ctx); err != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to load profiles from DB on startup: %v", err) + if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to load profiles from cache fallback on startup: %v", fallbackErr) + } + } + + if cache != nil { + cache.SubscribeUpdates(ctx, func() { + if err := svc.refreshLocalCache(context.Background()); err != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to refresh cache on notification: %v", err) + } + }) + } + + return svc +} + +// --- CRUD --- + +// List 获取所有模板 +func (s *TLSFingerprintProfileService) List(ctx context.Context) ([]*model.TLSFingerprintProfile, error) { + return s.repo.List(ctx) +} + +// GetByID 根据 ID 获取模板 +func (s *TLSFingerprintProfileService) GetByID(ctx context.Context, id int64) (*model.TLSFingerprintProfile, error) { + return s.repo.GetByID(ctx, id) +} + +// Create 创建模板 +func (s *TLSFingerprintProfileService) Create(ctx context.Context, profile *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) { + if err := profile.Validate(); err != nil { + return nil, err + } + + created, err := s.repo.Create(ctx, profile) + if err != nil { + return nil, err + } + + refreshCtx, cancel := 
s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return created, nil +} + +// Update 更新模板 +func (s *TLSFingerprintProfileService) Update(ctx context.Context, profile *model.TLSFingerprintProfile) (*model.TLSFingerprintProfile, error) { + if err := profile.Validate(); err != nil { + return nil, err + } + + updated, err := s.repo.Update(ctx, profile) + if err != nil { + return nil, err + } + + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return updated, nil +} + +// Delete 删除模板 +func (s *TLSFingerprintProfileService) Delete(ctx context.Context, id int64) error { + if err := s.repo.Delete(ctx, id); err != nil { + return err + } + + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) + + return nil +} + +// --- 热路径:运行时 Profile 查找 --- + +// GetProfileByID 根据 ID 从本地缓存获取 Profile(用于 DoWithTLS 热路径) +// 返回 nil 表示未找到,调用方应 fallback 到内置默认 Profile +func (s *TLSFingerprintProfileService) GetProfileByID(id int64) *tlsfingerprint.Profile { + s.localMu.RLock() + p, ok := s.localCache[id] + s.localMu.RUnlock() + + if ok && p != nil { + return p.ToTLSProfile() + } + return nil +} + +// getRandomProfile 从本地缓存中随机选择一个 Profile +func (s *TLSFingerprintProfileService) getRandomProfile() *tlsfingerprint.Profile { + s.localMu.RLock() + defer s.localMu.RUnlock() + + if len(s.localCache) == 0 { + return nil + } + + // 收集所有 profile + profiles := make([]*model.TLSFingerprintProfile, 0, len(s.localCache)) + for _, p := range s.localCache { + if p != nil { + profiles = append(profiles, p) + } + } + if len(profiles) == 0 { + return nil + } + + return profiles[rand.IntN(len(profiles))].ToTLSProfile() +} + +// ResolveTLSProfile 根据 Account 的配置解析出运行时 TLS Profile +// +// 逻辑: +// 1. 未启用 TLS 指纹 → 返回 nil(不伪装) +// 2. 启用 + 绑定了 profile_id → 从缓存查找对应 profile +// 3. 
启用 + 未绑定或找不到 → 返回空 Profile(使用代码内置默认值) +func (s *TLSFingerprintProfileService) ResolveTLSProfile(account *Account) *tlsfingerprint.Profile { + if account == nil || !account.IsTLSFingerprintEnabled() { + return nil + } + id := account.GetTLSFingerprintProfileID() + if id > 0 { + if p := s.GetProfileByID(id); p != nil { + return p + } + } + if id == -1 { + // 随机选择一个 profile + if p := s.getRandomProfile(); p != nil { + return p + } + } + // TLS 启用但无绑定 profile → 空 Profile → dialer 使用内置默认值 + return &tlsfingerprint.Profile{Name: "Built-in Default (Node.js 24.x)"} +} + +// --- 缓存管理 --- + +func (s *TLSFingerprintProfileService) refreshLocalCache(ctx context.Context) error { + if s.cache != nil { + if profiles, ok := s.cache.Get(ctx); ok { + s.setLocalCache(profiles) + return nil + } + } + return s.reloadFromDB(ctx) +} + +func (s *TLSFingerprintProfileService) reloadFromDB(ctx context.Context) error { + profiles, err := s.repo.List(ctx) + if err != nil { + return err + } + + if s.cache != nil { + if err := s.cache.Set(ctx, profiles); err != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to set cache: %v", err) + } + } + + s.setLocalCache(profiles) + return nil +} + +func (s *TLSFingerprintProfileService) setLocalCache(profiles []*model.TLSFingerprintProfile) { + m := make(map[int64]*model.TLSFingerprintProfile, len(profiles)) + for _, p := range profiles { + m[p.ID] = p + } + + s.localMu.Lock() + s.localCache = m + s.localMu.Unlock() +} + +func (s *TLSFingerprintProfileService) newCacheRefreshContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 3*time.Second) +} + +func (s *TLSFingerprintProfileService) invalidateAndNotify(ctx context.Context) { + if s.cache != nil { + if err := s.cache.Invalidate(ctx); err != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to invalidate cache: %v", err) + } + } + + if err := s.reloadFromDB(ctx); err != nil { + 
logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to refresh local cache: %v", err) + s.localMu.Lock() + s.localCache = make(map[int64]*model.TLSFingerprintProfile) + s.localMu.Unlock() + } + + if s.cache != nil { + if err := s.cache.NotifyUpdate(ctx); err != nil { + logger.LegacyPrintf("service.tls_fp_profile", "[TLSFPProfileService] Failed to notify cache update: %v", err) + } + } +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index cb8841b0..fb2b5210 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -12,6 +12,9 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) +// tokenRefreshTempUnschedDuration token 刷新重试耗尽后临时不可调度的持续时间 +const tokenRefreshTempUnschedDuration = 10 * time.Minute + // TokenRefreshService OAuth token自动刷新服务 // 定期检查并刷新即将过期的token type TokenRefreshService struct { @@ -29,8 +32,9 @@ type TokenRefreshService struct { privacyClientFactory PrivacyClientFactory proxyRepo ProxyRepository - stopCh chan struct{} - wg sync.WaitGroup + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } // NewTokenRefreshService 创建token刷新服务 @@ -125,9 +129,11 @@ func (s *TokenRefreshService) Start() { ) } -// Stop 停止刷新服务 +// Stop 停止刷新服务(可安全多次调用) func (s *TokenRefreshService) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) s.wg.Wait() slog.Info("token_refresh.service_stopped") } @@ -277,8 +283,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc newCredentials, err = refresher.Refresh(ctx, account) if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - account.Credentials = newCredentials - if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil { return fmt.Errorf("failed to save 
credentials: %w", saveErr) } } @@ -298,6 +303,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc "error", setErr, ) } + // 刷新失败但 access_token 可能仍有效,尝试设置隐私 + s.ensureOpenAIPrivacy(ctx, account) + s.ensureAntigravityPrivacy(ctx, account) return err } @@ -317,7 +325,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } } - // 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试) + // 可重试错误耗尽:临时标记账号不可调度,避免请求路径反复命中已知失败的账号 slog.Warn("token_refresh.retry_exhausted", "account_id", account.ID, "platform", account.Platform, @@ -325,6 +333,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc "error", lastErr, ) + // 刷新失败但 access_token 可能仍有效,尝试设置隐私 + s.ensureOpenAIPrivacy(ctx, account) + s.ensureAntigravityPrivacy(ctx, account) + + // 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试) + until := time.Now().Add(tokenRefreshTempUnschedDuration) + reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr) + if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil { + slog.Warn("token_refresh.set_temp_unschedulable_failed", + "account_id", account.ID, + "error", setErr, + ) + } else { + slog.Info("token_refresh.temp_unschedulable_set", + "account_id", account.ID, + "until", until.Format(time.RFC3339), + ) + } + return lastErr } @@ -387,6 +414,8 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A } // OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享 s.ensureOpenAIPrivacy(ctx, account) + // Antigravity OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则调用 setUserSettings + s.ensureAntigravityPrivacy(ctx, account) } // errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed @@ -406,6 +435,7 @@ func isNonRetryableRefreshError(err error) bool { "unauthorized_client", // 客户端未授权 "access_denied", // 访问被拒绝 "missing_project_id", // 缺少 project_id + "no refresh token available", } for _, needle := range nonRetryable { 
if strings.Contains(msg, needle) { @@ -424,11 +454,8 @@ func (s *TokenRefreshService) ensureOpenAIPrivacy(ctx context.Context, account * if s.privacyClientFactory == nil { return } - // 已设置过则跳过 - if account.Extra != nil { - if _, ok := account.Extra["privacy_mode"]; ok { - return - } + if shouldSkipOpenAIPrivacyEnsure(account.Extra) { + return } token, _ := account.Credentials["access_token"].(string) @@ -460,3 +487,50 @@ func (s *TokenRefreshService) ensureOpenAIPrivacy(ctx context.Context, account * ) } } + +// ensureAntigravityPrivacy 后台刷新中检查 Antigravity OAuth 账号隐私状态。 +// 仅做 Extra["privacy_mode"] 存在性检查,不发起 HTTP 请求,避免每轮循环产生额外网络开销。 +// 用户可通过前端 SetPrivacy 按钮强制重新设置。 +func (s *TokenRefreshService) ensureAntigravityPrivacy(ctx context.Context, account *Account) { + if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { + return + } + // 已设置过(无论成功或失败)则跳过,不发 HTTP + if account.Extra != nil { + if _, ok := account.Extra["privacy_mode"]; ok { + return + } + } + + token, _ := account.Credentials["access_token"].(string) + if token == "" { + return + } + + projectID, _ := account.Credentials["project_id"].(string) + + var proxyURL string + if account.ProxyID != nil && s.proxyRepo != nil { + if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { + proxyURL = p.URL() + } + } + + mode := setAntigravityPrivacy(ctx, token, projectID, proxyURL) + if mode == "" { + return + } + + if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil { + slog.Warn("token_refresh.update_antigravity_privacy_mode_failed", + "account_id", account.ID, + "error", err, + ) + } else { + applyAntigravityPrivacyMode(account, mode) + slog.Info("token_refresh.antigravity_privacy_mode_set", + "account_id", account.ID, + "privacy_mode", mode, + ) + } +} diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index f48de65e..2179a85e 100644 --- 
a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,19 +14,41 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - clearTempCalls int - lastAccount *Account - updateErr error + updateCalls int + fullUpdateCalls int + updateCredentialsCalls int + setErrorCalls int + clearTempCalls int + setTempUnschedCalls int + lastAccount *Account + updateErr error } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { r.updateCalls++ + r.fullUpdateCalls++ r.lastAccount = account return r.updateErr } +func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + cloned := cloneCredentials(credentials) + if r.accountsByID != nil { + if acc, ok := r.accountsByID[id]; ok && acc != nil { + acc.Credentials = cloned + r.lastAccount = acc + return nil + } + } + r.lastAccount = &Account{ID: id, Credentials: cloned} + return nil +} + func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ return nil @@ -37,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id return nil } +func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.setTempUnschedCalls++ + return nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -112,6 +139,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) require.Equal(t, 1, 
invalidator.calls) require.Equal(t, "new-token", account.GetCredential("access_token")) } @@ -249,9 +278,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 } +func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + resetAt := time.Now().Add(30 * time.Minute) + account := &Account{ + ID: 17, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + Credentials: map[string]any{ + "access_token": "old-token", + }, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) + require.NotNil(t, account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second) +} + // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} @@ -390,7 +453,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) - require.Equal(t, 1, repo.clearTempCalls) 
// DB 清除 + require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 } @@ -433,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t } } +func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + account := &Account{ + ID: 18, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("no refresh token available"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable") + require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state") +} + // TestIsNonRetryableRefreshError 测试不可重试错误判断 func TestIsNonRetryableRefreshError(t *testing.T) { tests := []struct { @@ -446,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) { {name: "invalid_client", err: errors.New("invalid_client"), expected: true}, {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, {name: "access_denied", err: errors.New("access_denied"), expected: true}, + {name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true}, {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true}, {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true}, } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 7f1bef7f..576841fa 100644 --- 
a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,12 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // RequestedModel is the client-requested model name recorded for stable user/admin display. + // Empty should be treated as Model for backward compatibility with historical rows. + RequestedModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Nil means no mapping was applied (requested model was used as-is). + UpstreamModel *string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string // ReasoningEffort is the request's reasoning effort level. diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go new file mode 100644 index 00000000..a7bcae99 --- /dev/null +++ b/backend/internal/service/usage_log_helpers.go @@ -0,0 +1,28 @@ +package service + +import "strings" + +func optionalTrimmedStringPtr(raw string) *string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil + } + return &trimmed +} + +// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and +// differs from compare; otherwise nil. Used to store upstream_model only when +// it differs from the requested model. 
+func optionalNonEqualStringPtr(value, compare string) *string { + if value == "" || value == compare { + return nil + } + return &value +} + +func forwardResultBillingModel(requestedModel, upstreamModel string) string { + if trimmed := strings.TrimSpace(requestedModel); trimmed != "" { + return trimmed + } + return strings.TrimSpace(upstreamModel) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 49ba3645..4045c0aa 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -21,6 +21,7 @@ type UserListFilters struct { Status string // User status filter Role string // User role filter Search string // Search in email, username + GroupName string // Filter by allowed group name (fuzzy match) Attributes map[int64]string // Custom attribute filters: attributeID -> value // IncludeSubscriptions controls whether ListWithFilters should load active subscriptions. // For large datasets this can be expensive; admin list pages should enable it on demand. 
@@ -46,6 +47,8 @@ type UserRepository interface { RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) // AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error + // RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限 + RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 05fe5056..e88694f5 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -46,7 +46,10 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } diff --git a/backend/internal/service/user_subscription_port.go b/backend/internal/service/user_subscription_port.go index 2dfc8d02..4484fae8 100644 --- a/backend/internal/service/user_subscription_port.go +++ b/backend/internal/service/user_subscription_port.go @@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByGroupID(ctx context.Context, groupID int64, params 
pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) - List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 7da72630..d79a3531 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -114,11 +114,13 @@ func ProvideAntigravityTokenProvider( tokenCache GeminiTokenCache, antigravityOAuthService *AntigravityOAuthService, refreshAPI *OAuthRefreshAPI, + tempUnschedCache TempUnschedCache, ) *AntigravityTokenProvider { p := NewAntigravityTokenProvider(accountRepo, tokenCache, antigravityOAuthService) executor := NewAntigravityTokenRefresher(antigravityOAuthService) p.SetRefreshAPI(refreshAPI, executor) p.SetRefreshPolicy(AntigravityProviderRefreshPolicy()) + p.SetTempUnschedCache(tempUnschedCache) return p } @@ -480,10 +482,12 @@ var ProviderSet = wire.NewSet( NewUsageCache, NewTotpService, NewErrorPassthroughService, + NewTLSFingerprintProfileService, NewDigestSessionStore, ProvideIdempotencyCoordinator, ProvideSystemOperationLockService, ProvideIdempotencyCleanupService, ProvideScheduledTestService, ProvideScheduledTestRunnerService, + NewGroupCapacityService, ) diff --git a/backend/internal/setup/handler.go b/backend/internal/setup/handler.go index 1531c97b..c2944ced 100644 --- a/backend/internal/setup/handler.go +++ b/backend/internal/setup/handler.go @@ -247,6 +247,12 @@ func install(c *gin.Context) { return } + req.Admin.Email = 
strings.TrimSpace(req.Admin.Email) + req.Database.Host = strings.TrimSpace(req.Database.Host) + req.Database.User = strings.TrimSpace(req.Database.User) + req.Database.DBName = strings.TrimSpace(req.Database.DBName) + req.Redis.Host = strings.TrimSpace(req.Redis.Host) + // ========== COMPREHENSIVE INPUT VALIDATION ========== // Database validation if !validateHostname(req.Database.Host) { @@ -319,13 +325,6 @@ func install(c *gin.Context) { return } - // Trim whitespace from string inputs - req.Admin.Email = strings.TrimSpace(req.Admin.Email) - req.Database.Host = strings.TrimSpace(req.Database.Host) - req.Database.User = strings.TrimSpace(req.Database.User) - req.Database.DBName = strings.TrimSpace(req.Database.DBName) - req.Redis.Host = strings.TrimSpace(req.Redis.Host) - cfg := &SetupConfig{ Database: req.Database, Redis: req.Redis, diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index de3b765a..9256d245 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -164,8 +164,8 @@ func NeedsSetup() bool { func TestDatabaseConnection(cfg *DatabaseConfig) error { // First, connect to the default 'postgres' database to check/create target database defaultDSN := fmt.Sprintf( - "host=%s port=%d user=%s password=%s dbname=postgres sslmode=%s", - cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.SSLMode, + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, ) db, err := sql.Open("postgres", defaultDSN) diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 41ce4d48..ffca98a5 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte { // Inject before headClose := []byte("") - return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1) + result := bytes.Replace(s.baseHTML, 
headClose, append(script, headClose...), 1) + + // Replace with custom site name so the browser tab shows it immediately + result = injectSiteTitle(result, settingsJSON) + + return result +} + +// injectSiteTitle replaces the static <title> in HTML with the configured site name. +// This ensures the browser tab shows the correct title before JS executes. +func injectSiteTitle(html, settingsJSON []byte) []byte { + var cfg struct { + SiteName string `json:"site_name"` + } + if err := json.Unmarshal(settingsJSON, &cfg); err != nil || cfg.SiteName == "" { + return html + } + + // Find and replace the existing <title>... + titleStart := bytes.Index(html, []byte("")) + titleEnd := bytes.Index(html, []byte("")) + if titleStart == -1 || titleEnd == -1 || titleEnd <= titleStart { + return html + } + + newTitle := []byte("" + cfg.SiteName + " - AI API Gateway") + var buf bytes.Buffer + buf.Write(html[:titleStart]) + buf.Write(newTitle) + buf.Write(html[titleEnd+len(""):]) + return buf.Bytes() } // replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index f270b624..fd47c4da 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -20,6 +20,78 @@ func init() { gin.SetMode(gin.TestMode) } +func TestInjectSiteTitle(t *testing.T) { + t.Run("replaces_title_with_site_name", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{"site_name":"MyCustomSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Contains(t, string(result), "MyCustomSite - AI API Gateway") + assert.NotContains(t, string(result), "Sub2API") + }) + + t.Run("returns_unchanged_when_site_name_empty", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{"site_name":""}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + 
t.Run("returns_unchanged_when_site_name_missing", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{"other_field":"value"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_invalid_json", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{invalid json}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_no_title_tag", func(t *testing.T) { + html := []byte(``) + settingsJSON := []byte(`{"site_name":"MyCustomSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_title_has_attributes", func(t *testing.T) { + // The function looks for "" literally, so attributes are not supported + // This is acceptable since index.html uses plain <title> without attributes + html := []byte(`<html><head><title lang="en">Sub2API`) + settingsJSON := []byte(`{"site_name":"NewSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + // Should return unchanged since with attributes is not matched + assert.Equal(t, string(html), string(result)) + }) + + t.Run("preserves_rest_of_html", func(t *testing.T) { + html := []byte(`<html><head><meta charset="UTF-8"><title>Sub2API
`) + settingsJSON := []byte(`{"site_name":"TestSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Contains(t, string(result), ``) + assert.Contains(t, string(result), ``) + assert.Contains(t, string(result), `
`) + assert.Contains(t, string(result), "TestSite - AI API Gateway") + }) +} + func TestReplaceNoncePlaceholder(t *testing.T) { t.Run("replaces_single_placeholder", func(t *testing.T) { html := []byte(``) diff --git a/backend/migrations/075_add_usage_log_upstream_model.sql b/backend/migrations/075_add_usage_log_upstream_model.sql new file mode 100644 index 00000000..7f9f8ec6 --- /dev/null +++ b/backend/migrations/075_add_usage_log_upstream_model.sql @@ -0,0 +1,4 @@ +-- Add upstream_model field to usage_logs. +-- Stores the actual upstream model name when it differs from the requested model +-- (i.e., when model mapping is applied). NULL means no mapping was applied. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); diff --git a/backend/migrations/075_map_haiku45_to_sonnet46.sql b/backend/migrations/075_map_haiku45_to_sonnet46.sql new file mode 100644 index 00000000..bbaa45e6 --- /dev/null +++ b/backend/migrations/075_map_haiku45_to_sonnet46.sql @@ -0,0 +1,17 @@ +-- Map claude-haiku-4-5 variants target from claude-sonnet-4-5 to claude-sonnet-4-6 +-- +-- Only updates when the current target is exactly claude-sonnet-4-5. + +-- 1. claude-haiku-4-5 +UPDATE accounts +SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5}', '"claude-sonnet-4-6"') +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping'->>'claude-haiku-4-5' = 'claude-sonnet-4-5'; + +-- 2. 
claude-haiku-4-5-20251001 +UPDATE accounts +SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5-20251001}', '"claude-sonnet-4-6"') +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping'->>'claude-haiku-4-5-20251001' = 'claude-sonnet-4-5'; diff --git a/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql new file mode 100644 index 00000000..9eee61be --- /dev/null +++ b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql @@ -0,0 +1,3 @@ +-- Support upstream_model / mapping model distribution aggregations with time-range filters. +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model +ON usage_logs (created_at, model, upstream_model); diff --git a/backend/migrations/077_add_usage_log_requested_model.sql b/backend/migrations/077_add_usage_log_requested_model.sql new file mode 100644 index 00000000..4b87df86 --- /dev/null +++ b/backend/migrations/077_add_usage_log_requested_model.sql @@ -0,0 +1,3 @@ +-- Add requested_model field to usage_logs for normalized request/upstream model tracking. +-- NULL means historical rows written before requested_model dual-write was introduced. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100); diff --git a/backend/migrations/078_add_usage_log_requested_model_index_notx.sql b/backend/migrations/078_add_usage_log_requested_model_index_notx.sql new file mode 100644 index 00000000..c3412562 --- /dev/null +++ b/backend/migrations/078_add_usage_log_requested_model_index_notx.sql @@ -0,0 +1,3 @@ +-- Support requested_model / upstream_model aggregations with time-range filters. 
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_requested_model_upstream_model +ON usage_logs (created_at, requested_model, upstream_model); diff --git a/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql new file mode 100644 index 00000000..56f83b84 --- /dev/null +++ b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql @@ -0,0 +1,28 @@ +-- Ops error logs: add endpoint, model mapping, and request_type fields +-- to match usage_logs observability coverage. +-- +-- All columns are nullable with no default to preserve backward compatibility +-- with existing rows. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256), + ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256); + +-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100), + ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); + +-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS request_type SMALLINT; + +COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.'; +COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.'; +COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).'; +COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. 
NULL means no mapping applied.'; +COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. Matches usage_logs.request_type semantics.'; diff --git a/backend/migrations/080_create_tls_fingerprint_profiles.sql b/backend/migrations/080_create_tls_fingerprint_profiles.sql new file mode 100644 index 00000000..c13c21f8 --- /dev/null +++ b/backend/migrations/080_create_tls_fingerprint_profiles.sql @@ -0,0 +1,29 @@ +-- Create tls_fingerprint_profiles table for managing TLS fingerprint templates. +-- Each profile contains ClientHello parameters to simulate specific client TLS handshake characteristics. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +CREATE TABLE IF NOT EXISTS tls_fingerprint_profiles ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL UNIQUE, + description TEXT, + enable_grease BOOLEAN NOT NULL DEFAULT false, + cipher_suites JSONB, + curves JSONB, + point_formats JSONB, + signature_algorithms JSONB, + alpn_protocols JSONB, + supported_versions JSONB, + key_share_groups JSONB, + psk_modes JSONB, + extensions JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +COMMENT ON TABLE tls_fingerprint_profiles IS 'TLS fingerprint templates for simulating specific client TLS handshake characteristics'; +COMMENT ON COLUMN tls_fingerprint_profiles.name IS 'Unique profile name, e.g. 
"macOS Node.js v24"'; +COMMENT ON COLUMN tls_fingerprint_profiles.enable_grease IS 'Whether to insert GREASE values in ClientHello extensions'; +COMMENT ON COLUMN tls_fingerprint_profiles.cipher_suites IS 'TLS cipher suite list as JSON array of uint16 (order-sensitive, affects JA3)'; +COMMENT ON COLUMN tls_fingerprint_profiles.extensions IS 'TLS extension type IDs in send order as JSON array of uint16'; diff --git a/backend/migrations/081_add_group_account_filter.sql b/backend/migrations/081_add_group_account_filter.sql new file mode 100644 index 00000000..0afb21d9 --- /dev/null +++ b/backend/migrations/081_add_group_account_filter.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups ADD COLUMN IF NOT EXISTS require_oauth_only BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE groups ADD COLUMN IF NOT EXISTS require_privacy_set BOOLEAN NOT NULL DEFAULT false; diff --git a/backend/migrations/README.md b/backend/migrations/README.md index 47f6fa35..40455ad9 100644 --- a/backend/migrations/README.md +++ b/backend/migrations/README.md @@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql` ## Migration File Structure -```sql --- +goose Up --- +goose StatementBegin --- Your forward migration SQL here --- +goose StatementEnd +This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is. --- +goose Down --- +goose StatementBegin --- Your rollback migration SQL here --- +goose StatementEnd +- Regular migrations (`*.sql`): executed in a transaction. +- Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`). + +```sql +-- Forward-only migration (recommended) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100); ``` +> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file. 
+ ## Important Rules ### ⚠️ Immutability Principle @@ -66,9 +66,9 @@ Why? touch migrations/018_your_change.sql ``` -2. **Write Up and Down migrations** - - Up: Apply the change - - Down: Revert the change (should be symmetric with Up) +2. **Write forward-only migration SQL** + - Put only the intended schema change in the file + - If rollback is needed, create a new migration file to revert 3. **Test locally** ```bash @@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql ## Example Migration ```sql --- +goose Up --- +goose StatementBegin -- Add tier_id field to Gemini OAuth accounts for quota tracking UPDATE accounts SET credentials = jsonb_set( @@ -157,17 +155,6 @@ SET credentials = jsonb_set( WHERE platform = 'gemini' AND type = 'oauth' AND credentials->>'tier_id' IS NULL; --- +goose StatementEnd - --- +goose Down --- +goose StatementBegin --- Remove tier_id field -UPDATE accounts -SET credentials = credentials - 'tier_id' -WHERE platform = 'gemini' - AND type = 'oauth' - AND credentials->>'tier_id' = 'LEGACY'; --- +goose StatementEnd ``` ## Troubleshooting @@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW()); ## References - Migration runner: `internal/repository/migrations_runner.go` -- Goose syntax: https://github.com/pressly/goose - PostgreSQL docs: https://www.postgresql.org/docs/ diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 72860bf9..0a096257 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5173,6 +5173,71 @@ "supports_tool_choice": true, "supports_vision": true }, + "gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + 
"output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/deploy/Dockerfile b/deploy/Dockerfile index 0f4f1de9..7caa5ca6 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -82,6 +82,7 @@ RUN apk add --no-cache \ ca-certificates \ tzdata \ curl \ + su-exec \ && rm -rf /var/cache/apk/* # Create non-root user @@ -97,8 +98,9 @@ COPY --from=backend-builder /app/sub2api /app/sub2api # Create data directory RUN mkdir -p /app/data && chown -R sub2api:sub2api /app -# Switch to non-root user -USER sub2api +# Copy 
entrypoint script (fixes volume permissions then drops to sub2api) +COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh +RUN chmod +x /app/docker-entrypoint.sh # Expose port (can be overridden by SERVER_PORT env var) EXPOSE 8080 @@ -107,5 +109,6 @@ EXPOSE 8080 HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 -# Run the application -ENTRYPOINT ["/app/sub2api"] +# Run the application (entrypoint fixes /app/data ownership then execs as sub2api) +ENTRYPOINT ["/app/docker-entrypoint.sh"] +CMD ["/app/sub2api"] diff --git a/deploy/README.md b/deploy/README.md index 807bf510..dd311721 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -53,13 +53,13 @@ chmod +x docker-deploy.sh **After running the script:** ```bash # Start services -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d # View logs -docker-compose -f docker-compose.local.yml logs -f sub2api +docker compose -f docker-compose.local.yml logs -f sub2api # If admin password was auto-generated, find it in logs: -docker-compose -f docker-compose.local.yml logs sub2api | grep "admin password" +docker compose -f docker-compose.local.yml logs sub2api | grep "admin password" # Access Web UI # http://localhost:8080 @@ -88,10 +88,10 @@ echo "TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY}" >> .env mkdir -p data postgres_data redis_data # Start all services using local directory version -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d # View logs (check for auto-generated admin password) -docker-compose -f docker-compose.local.yml logs -f sub2api +docker compose -f docker-compose.local.yml logs -f sub2api # Access Web UI # http://localhost:8080 @@ -121,7 +121,7 @@ When using Docker Compose with `AUTO_SETUP=true`: 3. 
If `ADMIN_PASSWORD` is not set, check logs for the generated password: ```bash - docker-compose logs sub2api | grep "admin password" + docker compose logs sub2api | grep "admin password" ``` ### Database Migration Notes (PostgreSQL) @@ -162,23 +162,23 @@ For **local directory version** (docker-compose.local.yml): ```bash # Start services -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d # Stop services -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down # View logs -docker-compose -f docker-compose.local.yml logs -f sub2api +docker compose -f docker-compose.local.yml logs -f sub2api # Restart Sub2API only -docker-compose -f docker-compose.local.yml restart sub2api +docker compose -f docker-compose.local.yml restart sub2api # Update to latest version -docker-compose -f docker-compose.local.yml pull -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml pull +docker compose -f docker-compose.local.yml up -d # Remove all data (caution!) -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down rm -rf data/ postgres_data/ redis_data/ ``` @@ -186,23 +186,23 @@ For **named volumes version** (docker-compose.yml): ```bash # Start services -docker-compose up -d +docker compose up -d # Stop services -docker-compose down +docker compose down # View logs -docker-compose logs -f sub2api +docker compose logs -f sub2api # Restart Sub2API only -docker-compose restart sub2api +docker compose restart sub2api # Update to latest version -docker-compose pull -docker-compose up -d +docker compose pull +docker compose up -d # Remove all data (caution!) 
-docker-compose down -v +docker compose down -v ``` ### Environment Variables @@ -232,7 +232,7 @@ When using `docker-compose.local.yml`, all data is stored in local directories, ```bash # On source server: Stop services and create archive cd /path/to/deployment -docker-compose -f docker-compose.local.yml down +docker compose -f docker-compose.local.yml down cd .. tar czf sub2api-complete.tar.gz deployment/ @@ -242,7 +242,7 @@ scp sub2api-complete.tar.gz user@new-server:/path/to/destination/ # On new server: Extract and start tar xzf sub2api-complete.tar.gz cd deployment/ -docker-compose -f docker-compose.local.yml up -d +docker compose -f docker-compose.local.yml up -d ``` Your entire deployment (configuration + data) is migrated! @@ -492,19 +492,19 @@ For **local directory version**: ```bash # Check container status -docker-compose -f docker-compose.local.yml ps +docker compose -f docker-compose.local.yml ps # View detailed logs -docker-compose -f docker-compose.local.yml logs --tail=100 sub2api +docker compose -f docker-compose.local.yml logs --tail=100 sub2api # Check database connection -docker-compose -f docker-compose.local.yml exec postgres pg_isready +docker compose -f docker-compose.local.yml exec postgres pg_isready # Check Redis connection -docker-compose -f docker-compose.local.yml exec redis redis-cli ping +docker compose -f docker-compose.local.yml exec redis redis-cli ping # Restart all services -docker-compose -f docker-compose.local.yml restart +docker compose -f docker-compose.local.yml restart # Check data directories ls -la data/ postgres_data/ redis_data/ @@ -514,19 +514,19 @@ For **named volumes version**: ```bash # Check container status -docker-compose ps +docker compose ps # View detailed logs -docker-compose logs --tail=100 sub2api +docker compose logs --tail=100 sub2api # Check database connection -docker-compose exec postgres pg_isready +docker compose exec postgres pg_isready # Check Redis connection -docker-compose exec redis redis-cli 
ping +docker compose exec redis redis-cli ping # Restart all services -docker-compose restart +docker compose restart ``` ### Binary Install diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 2058ced1..8f60acd5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -865,10 +865,10 @@ rate_limit: pricing: # URL to fetch model pricing data (default: pinned model-price-repo commit) # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) - remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" + remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main/model_prices_and_context_window.json" # Hash verification URL (optional) # 哈希校验 URL(可选) - hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" + hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main/model_prices_and_context_window.sha256" # Local data directory for caching # 本地数据缓存目录 data_dir: "./data" diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index d404ac0b..5aea78fb 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -38,7 +38,7 @@ services: - ./data:/app/data # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: - # - ./config.yaml:/app/data/config.yaml:ro + # - ./config.yaml:/app/data/config.yaml environment: # ======================================================================= # Auto Setup (REQUIRED for Docker deployment) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 99b05446..d1a564a3 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -29,6 +29,9 @@ services: - 
"${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" volumes: - sub2api_data:/app/data + # Optional: Mount custom config.yaml (uncomment and create the file first) + # Copy config.example.yaml to config.yaml, modify it, then uncomment: + # - ./config.yaml:/app/data/config.yaml environment: # Auto Setup - AUTO_SETUP=true diff --git a/deploy/docker-entrypoint.sh b/deploy/docker-entrypoint.sh new file mode 100644 index 00000000..47ab6bf1 --- /dev/null +++ b/deploy/docker-entrypoint.sh @@ -0,0 +1,23 @@ +#!/bin/sh +set -e + +# Fix data directory permissions when running as root. +# Docker named volumes / host bind-mounts may be owned by root, +# preventing the non-root sub2api user from writing files. +if [ "$(id -u)" = "0" ]; then + mkdir -p /app/data + # Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro) + chown -R sub2api:sub2api /app/data 2>/dev/null || true + # Re-invoke this script as sub2api so the flag-detection below + # also runs under the correct user. + exec su-exec sub2api "$0" "$@" +fi + +# Compatibility: if the first arg looks like a flag (e.g. --help), +# prepend the default binary so it behaves the same as the old +# ENTRYPOINT ["/app/sub2api"] style. 
+if [ "${1#-}" != "$1" ]; then + set -- /app/sub2api "$@" +fi + +exec "$@" diff --git a/frontend/package.json b/frontend/package.json index 1b380b17..d2a6dede 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -16,6 +16,7 @@ }, "dependencies": { "@lobehub/icons": "^4.0.2", + "@tanstack/vue-virtual": "^3.13.23", "@vueuse/core": "^10.7.0", "axios": "^1.13.5", "chart.js": "^4.4.1", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 37c384b4..505b72f3 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -11,6 +11,9 @@ importers: '@lobehub/icons': specifier: ^4.0.2 version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + '@tanstack/vue-virtual': + specifier: ^3.13.23 + version: 3.13.23(vue@3.5.26(typescript@5.6.3)) '@vueuse/core': specifier: ^10.7.0 version: 10.11.1(vue@3.5.26(typescript@5.6.3)) @@ -1376,6 +1379,14 @@ packages: peerDependencies: react: '>= 16.3.0' + '@tanstack/virtual-core@3.13.23': + resolution: {integrity: sha512-zSz2Z2HNyLjCplANTDyl3BcdQJc2k1+yyFoKhNRmCr7V7dY8o8q5m8uFTI1/Pg1kL+Hgrz6u3Xo6eFUB7l66cg==} + + '@tanstack/vue-virtual@3.13.23': + resolution: {integrity: sha512-b5jPluAR6U3eOq6GWAYSpj3ugnAIZgGR0e6aGAgyRse0Yu6MVQQ0ZWm9SArSXWtageogn6bkVD8D//c4IjW3xQ==} + peerDependencies: + vue: ^2.7.0 || ^3.0.0 + '@types/d3-array@3.2.2': resolution: {integrity: sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==} @@ -5808,6 +5819,13 @@ snapshots: dependencies: react: 19.2.3 + '@tanstack/virtual-core@3.13.23': {} + + '@tanstack/vue-virtual@3.13.23(vue@3.5.26(typescript@5.6.3))': + dependencies: + '@tanstack/virtual-core': 3.13.23 + vue: 3.5.26(typescript@5.6.3) + '@types/d3-array@3.2.2': {} '@types/d3-axis@3.0.6': diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 4fc6a7c8..7485aa1a 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -3,6 +3,7 
@@ import { RouterView, useRouter, useRoute } from 'vue-router' import { onMounted, onBeforeUnmount, watch } from 'vue' import Toast from '@/components/common/Toast.vue' import NavigationProgress from '@/components/common/NavigationProgress.vue' +import { resolveDocumentTitle } from '@/router/title' import AnnouncementPopup from '@/components/common/AnnouncementPopup.vue' import { useAppStore, useAuthStore, useSubscriptionStore, useAnnouncementStore } from '@/stores' import { getSetupStatus } from '@/api/setup' @@ -104,6 +105,9 @@ onMounted(async () => { // Load public settings into appStore (will be cached for other components) await appStore.fetchPublicSettings() + + // Re-resolve document title now that siteName is available + document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string) }) diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 23d50d3a..fd93fe7e 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -36,6 +36,7 @@ export async function list( status?: string group?: string search?: string + privacy_mode?: string lite?: string }, options?: { @@ -66,7 +67,9 @@ export async function listWithEtag( platform?: string type?: string status?: string + group?: string search?: string + privacy_mode?: string lite?: string }, options?: { @@ -223,8 +226,10 @@ export async function clearError(id: number): Promise { * @param id - Account ID * @returns Account usage info */ -export async function getUsage(id: number): Promise { - const { data } = await apiClient.get(`/admin/accounts/${id}/usage`) +export async function getUsage(id: number, source?: 'passive' | 'active'): Promise { + const { data } = await apiClient.get(`/admin/accounts/${id}/usage`, { + params: source ? 
{ source } : undefined + }) return data } @@ -547,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise> { - const payload: { refresh_token: string; proxy_id?: number } = { + const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = { refresh_token: refreshToken } if (proxyId) { payload.proxy_id = proxyId } + if (clientId) { + payload.client_id = clientId + } const { data } = await apiClient.post>(endpoint, payload) return data } @@ -618,6 +627,16 @@ export async function batchRefresh(accountIds: number[]): Promise { + const { data } = await apiClient.post(`/admin/accounts/${id}/set-privacy`) + return data +} + export const accountsAPI = { list, listWithEtag, @@ -654,7 +673,8 @@ export const accountsAPI = { importData, getAntigravityDefaultModelMapping, batchClearError, - batchRefresh + batchRefresh, + setPrivacy } export default accountsAPI diff --git a/frontend/src/api/admin/backup.ts b/frontend/src/api/admin/backup.ts index d349c862..bccb1f80 100644 --- a/frontend/src/api/admin/backup.ts +++ b/frontend/src/api/admin/backup.ts @@ -29,6 +29,10 @@ export interface BackupRecord { started_at: string finished_at?: string expires_at?: string + progress?: string + restore_status?: string + restore_error?: string + restored_at?: string } export interface CreateBackupRequest { @@ -69,7 +73,7 @@ export async function updateSchedule(config: BackupScheduleConfig): Promise { - const { data } = await apiClient.post('/admin/backups', req || {}, { timeout: 600000 }) + const { data } = await apiClient.post('/admin/backups', req || {}) return data } @@ -93,8 +97,9 @@ export async function getDownloadURL(id: string): Promise<{ url: string }> { } // Restore -export async function restoreBackup(id: string, password: string): Promise { - await apiClient.post(`/admin/backups/${id}/restore`, { password }, { timeout: 600000 }) +export async function restoreBackup(id: string, password: string): Promise { + const { data } = await 
apiClient.post(`/admin/backups/${id}/restore`, { password }) + return data } export const backupAPI = { diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 85200506..15d1540f 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -12,6 +12,7 @@ import type { ApiKeyUsageTrendPoint, UserUsageTrendPoint, UserSpendingRankingResponse, + UserBreakdownItem, UsageRequestType } from '@/types' @@ -80,6 +81,7 @@ export interface ModelStatsParams { user_id?: number api_key_id?: number model?: string + model_source?: 'requested' | 'upstream' | 'mapping' account_id?: number group_id?: number request_type?: UsageRequestType @@ -156,6 +158,30 @@ export async function getGroupStats(params?: GroupStatsParams): Promise { + const { data } = await apiClient.get('/admin/dashboard/user-breakdown', { + params + }) + return data +} + /** * Get dashboard snapshot v2 (aggregated response for heavy admin pages). */ diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 7c2658fa..5885dc6a 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -218,6 +218,34 @@ export async function batchSetGroupRateMultipliers( return data } +/** + * Get usage summary (today + cumulative cost) for all groups + * @param timezone - IANA timezone string (e.g. "Asia/Shanghai") + * @returns Array of group usage summaries + */ +export async function getUsageSummary( + timezone?: string +): Promise<{ group_id: number; today_cost: number; total_cost: number }[]> { + const { data } = await apiClient.get< + { group_id: number; today_cost: number; total_cost: number }[] + >('/admin/groups/usage-summary', { + params: timezone ? 
{ timezone } : undefined + }) + return data +} + +/** + * Get capacity summary (concurrency/sessions/RPM) for all active groups + */ +export async function getCapacitySummary(): Promise< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] +> { + const { data } = await apiClient.get< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] + >('/admin/groups/capacity-summary') + return data +} + export const groupsAPI = { list, getAll, @@ -232,7 +260,9 @@ export const groupsAPI = { getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, - updateSortOrder + updateSortOrder, + getUsageSummary, + getCapacitySummary } export default groupsAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index a6ebfc2c..9a3fb8c5 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -24,6 +24,7 @@ import dataManagementAPI from './dataManagement' import apiKeysAPI from './apiKeys' import scheduledTestsAPI from './scheduledTests' import backupAPI from './backup' +import tlsFingerprintProfileAPI from './tlsFingerprintProfile' /** * Unified admin API object for convenient access @@ -49,7 +50,8 @@ export const adminAPI = { dataManagement: dataManagementAPI, apiKeys: apiKeysAPI, scheduledTests: scheduledTestsAPI, - backup: backupAPI + backup: backupAPI, + tlsFingerprintProfiles: tlsFingerprintProfileAPI } export { @@ -73,7 +75,8 @@ export { dataManagementAPI, apiKeysAPI, scheduledTestsAPI, - backupAPI + backupAPI, + tlsFingerprintProfileAPI } export default adminAPI @@ -82,3 +85,4 @@ export default adminAPI export type { BalanceHistoryItem } from './users' export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' export type { BackupAgentHealth, 
DataManagementConfig } from './dataManagement' +export type { TLSFingerprintProfile, CreateProfileRequest, UpdateProfileRequest } from './tlsFingerprintProfile' diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index 64f6a6d0..ac58eff4 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -969,6 +969,13 @@ export interface OpsErrorLog { client_ip?: string | null request_path?: string stream?: boolean + + // Error observability context (endpoint + model mapping) + inbound_endpoint?: string + upstream_endpoint?: string + requested_model?: string + upstream_model?: string + request_type?: number | null } export interface OpsErrorDetail extends OpsErrorLog { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 040cf71e..cabdd5aa 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { CustomMenuItem } from '@/types' +import type { CustomMenuItem, CustomEndpoint } from '@/types' export interface DefaultSubscriptionSetting { group_id: number @@ -43,6 +43,7 @@ export interface SystemSettings { sora_client_enabled: boolean backend_mode_enabled: boolean custom_menu_items: CustomMenuItem[] + custom_endpoints: CustomEndpoint[] // SMTP settings smtp_host: string smtp_port: number @@ -81,9 +82,14 @@ export interface SystemSettings { // Claude Code version check min_claude_code_version: string + max_claude_code_version: string // 分组隔离 allow_ungrouped_key_scheduling: boolean + + // Gateway forwarding behavior + enable_fingerprint_unification: boolean + enable_metadata_passthrough: boolean } export interface UpdateSettingsRequest { @@ -111,6 +117,7 @@ export interface UpdateSettingsRequest { sora_client_enabled?: boolean backend_mode_enabled?: boolean custom_menu_items?: CustomMenuItem[] + custom_endpoints?: CustomEndpoint[] smtp_host?: string smtp_port?: number smtp_username?: string @@ 
-137,7 +144,10 @@ export interface UpdateSettingsRequest { ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string ops_metrics_interval_seconds?: number min_claude_code_version?: string + max_claude_code_version?: string allow_ungrouped_key_scheduling?: boolean + enable_fingerprint_unification?: boolean + enable_metadata_passthrough?: boolean } /** @@ -242,6 +252,33 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { return data } +// ==================== Overload Cooldown Settings ==================== + +/** + * Overload cooldown settings interface (529 handling) + */ +export interface OverloadCooldownSettings { + enabled: boolean + cooldown_minutes: number +} + +export async function getOverloadCooldownSettings(): Promise { + const { data } = await apiClient.get('/admin/settings/overload-cooldown') + return data +} + +export async function updateOverloadCooldownSettings( + settings: OverloadCooldownSettings +): Promise { + const { data } = await apiClient.put( + '/admin/settings/overload-cooldown', + settings + ) + return data +} + +// ==================== Stream Timeout Settings ==================== + /** * Stream timeout settings interface */ @@ -286,6 +323,8 @@ export interface RectifierSettings { enabled: boolean thinking_signature_enabled: boolean thinking_budget_enabled: boolean + apikey_signature_enabled: boolean + apikey_signature_patterns: string[] } /** @@ -499,6 +538,8 @@ export const settingsAPI = { getAdminApiKey, regenerateAdminApiKey, deleteAdminApiKey, + getOverloadCooldownSettings, + updateOverloadCooldownSettings, getStreamTimeoutSettings, updateStreamTimeoutSettings, getRectifierSettings, diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index 7557e3ad..611f67c2 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -27,6 +27,7 @@ export async function list( status?: 'active' | 'expired' | 'revoked' user_id?: number 
group_id?: number + platform?: string sort_by?: string sort_order?: 'asc' | 'desc' }, diff --git a/frontend/src/api/admin/tlsFingerprintProfile.ts b/frontend/src/api/admin/tlsFingerprintProfile.ts new file mode 100644 index 00000000..f6a26dd5 --- /dev/null +++ b/frontend/src/api/admin/tlsFingerprintProfile.ts @@ -0,0 +1,98 @@ +/** + * Admin TLS Fingerprint Profile API endpoints + * Handles TLS fingerprint profile CRUD for administrators + */ + +import { apiClient } from '../client' + +/** + * TLS fingerprint profile interface + */ +export interface TLSFingerprintProfile { + id: number + name: string + description: string | null + enable_grease: boolean + cipher_suites: number[] + curves: number[] + point_formats: number[] + signature_algorithms: number[] + alpn_protocols: string[] + supported_versions: number[] + key_share_groups: number[] + psk_modes: number[] + extensions: number[] + created_at: string + updated_at: string +} + +/** + * Create profile request + */ +export interface CreateProfileRequest { + name: string + description?: string | null + enable_grease?: boolean + cipher_suites?: number[] + curves?: number[] + point_formats?: number[] + signature_algorithms?: number[] + alpn_protocols?: string[] + supported_versions?: number[] + key_share_groups?: number[] + psk_modes?: number[] + extensions?: number[] +} + +/** + * Update profile request + */ +export interface UpdateProfileRequest { + name?: string + description?: string | null + enable_grease?: boolean + cipher_suites?: number[] + curves?: number[] + point_formats?: number[] + signature_algorithms?: number[] + alpn_protocols?: string[] + supported_versions?: number[] + key_share_groups?: number[] + psk_modes?: number[] + extensions?: number[] +} + +export async function list(): Promise { + const { data } = await apiClient.get('/admin/tls-fingerprint-profiles') + return data +} + +export async function getById(id: number): Promise { + const { data } = await 
apiClient.get(`/admin/tls-fingerprint-profiles/${id}`) + return data +} + +export async function create(profileData: CreateProfileRequest): Promise { + const { data } = await apiClient.post('/admin/tls-fingerprint-profiles', profileData) + return data +} + +export async function update(id: number, updates: UpdateProfileRequest): Promise { + const { data } = await apiClient.put(`/admin/tls-fingerprint-profiles/${id}`, updates) + return data +} + +export async function deleteProfile(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/tls-fingerprint-profiles/${id}`) + return data +} + +export const tlsFingerprintProfileAPI = { + list, + getById, + create, + update, + delete: deleteProfile +} + +export default tlsFingerprintProfileAPI diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index d631a5b7..bbf0ab51 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -21,6 +21,7 @@ export async function list( status?: 'active' | 'disabled' role?: 'admin' | 'user' search?: string + group_name?: string // fuzzy filter by allowed group name attributes?: Record // attributeId -> value include_subscriptions?: boolean }, @@ -35,6 +36,7 @@ export async function list( status: filters?.status, role: filters?.role, search: filters?.search, + group_name: filters?.group_name, include_subscriptions: filters?.include_subscriptions } @@ -223,6 +225,25 @@ export async function getUserBalanceHistory( return data } +/** + * Replace user's exclusive group + * @param userId - User ID + * @param oldGroupId - Current group ID to replace + * @param newGroupId - New group ID to replace with + * @returns Number of migrated keys + */ +export async function replaceGroup( + userId: number, + oldGroupId: number, + newGroupId: number +): Promise<{ migrated_keys: number }> { + const { data } = await apiClient.post<{ migrated_keys: number }>( + `/admin/users/${userId}/replace-group`, 
+ { old_group_id: oldGroupId, new_group_id: newGroupId } + ) + return data +} + export const usersAPI = { list, getById, @@ -234,7 +255,8 @@ export const usersAPI = { toggleStatus, getUserApiKeys, getUserUsageStats, - getUserBalanceHistory + getUserBalanceHistory, + replaceGroup } export default usersAPI diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 09236edd..37e18c35 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -67,21 +67,54 @@ :resets-at="usageInfo.seven_day_sonnet.resets_at" color="purple" /> + + +
+ + {{ t('admin.accounts.usageWindow.passiveSampled') }} + + +
-
- + @@ -389,8 +374,43 @@
- -
+ +
+ +
+
+ + {{ formatKeyRequests }} req + + + {{ formatKeyTokens }} + + + A ${{ formatKeyCost }} + + + U ${{ formatKeyUserCost }} + +
+
+ +
+
+
+
+
+ + + + +
-
-
-
@@ -422,17 +444,28 @@ import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api/admin' import type { Account, AccountUsageInfo, GeminiCredentials, WindowStats } from '@/types' import { buildOpenAIUsageRefreshKey } from '@/utils/accountUsageRefresh' -import { resolveCodexUsageWindow } from '@/utils/codexUsage' +import { formatCompactNumber } from '@/utils/format' import UsageProgressBar from './UsageProgressBar.vue' import AccountQuotaInfo from './AccountQuotaInfo.vue' -const props = defineProps<{ - account: Account -}>() +const props = withDefaults( + defineProps<{ + account: Account + todayStats?: WindowStats | null + todayStatsLoading?: boolean + manualRefreshToken?: number + }>(), + { + todayStats: null, + todayStatsLoading: false, + manualRefreshToken: 0 + } +) const { t } = useI18n() const loading = ref(false) +const activeQueryLoading = ref(false) const error = ref(null) const usageInfo = ref(null) @@ -470,54 +503,17 @@ const geminiUsageAvailable = computed(() => { ) }) -const codex5hWindow = computed(() => resolveCodexUsageWindow(props.account.extra, '5h')) -const codex7dWindow = computed(() => resolveCodexUsageWindow(props.account.extra, '7d')) - -// OpenAI Codex usage computed properties -const hasCodexUsage = computed(() => { - return codex5hWindow.value.usedPercent !== null || codex7dWindow.value.usedPercent !== null -}) - const hasOpenAIUsageFallback = computed(() => { if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return false return !!usageInfo.value?.five_hour || !!usageInfo.value?.seven_day }) -const isActiveOpenAIRateLimited = computed(() => { - if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return false - if (!props.account.rate_limit_reset_at) return false - const resetAt = Date.parse(props.account.rate_limit_reset_at) - return !Number.isNaN(resetAt) && resetAt > Date.now() -}) - -const preferFetchedOpenAIUsage = computed(() => { - return (isActiveOpenAIRateLimited.value || 
isOpenAICodexSnapshotStale.value) && hasOpenAIUsageFallback.value -}) - const openAIUsageRefreshKey = computed(() => buildOpenAIUsageRefreshKey(props.account)) -const isOpenAICodexSnapshotStale = computed(() => { - if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return false - const extra = props.account.extra as Record | undefined - const updatedAtRaw = extra?.codex_usage_updated_at - if (!updatedAtRaw) return true - const updatedAt = Date.parse(String(updatedAtRaw)) - if (Number.isNaN(updatedAt)) return true - return Date.now() - updatedAt >= 10 * 60 * 1000 -}) - const shouldAutoLoadUsageOnMount = computed(() => { - if (props.account.platform === 'openai' && props.account.type === 'oauth') { - return isActiveOpenAIRateLimited.value || !hasCodexUsage.value || isOpenAICodexSnapshotStale.value - } return shouldFetchUsage.value }) -const codex5hUsedPercent = computed(() => codex5hWindow.value.usedPercent) -const codex5hResetAt = computed(() => codex5hWindow.value.resetAt) -const codex7dUsedPercent = computed(() => codex7dWindow.value.usedPercent) -const codex7dResetAt = computed(() => codex7dWindow.value.resetAt) - // Antigravity quota types (用于 API 返回的数据) interface AntigravityUsageResult { utilization: number @@ -925,14 +921,18 @@ const copyValidationURL = async () => { } } -const loadUsage = async () => { +const isAnthropicOAuthOrSetupToken = computed(() => { + return props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token') +}) + +const loadUsage = async (source?: 'passive' | 'active') => { if (!shouldFetchUsage.value) return loading.value = true error.value = null try { - usageInfo.value = await adminAPI.accounts.getUsage(props.account.id) + usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, source) } catch (e: any) { error.value = t('common.error') console.error('Failed to load usage:', e) @@ -941,6 +941,17 @@ const loadUsage = async () => { } } +const 
loadActiveUsage = async () => { + activeQueryLoading.value = true + try { + usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active') + } catch (e: any) { + console.error('Failed to load active usage:', e) + } finally { + activeQueryLoading.value = false + } +} + // ===== API Key quota progress bars ===== interface QuotaBarInfo { @@ -1006,18 +1017,53 @@ const quotaTotalBar = computed((): QuotaBarInfo | null => { return makeQuotaBar(props.account.quota_used ?? 0, limit) }) +// ===== Key account today stats formatters ===== + +const formatKeyRequests = computed(() => { + if (!props.todayStats) return '' + return formatCompactNumber(props.todayStats.requests, { allowBillions: false }) +}) + +const formatKeyTokens = computed(() => { + if (!props.todayStats) return '' + return formatCompactNumber(props.todayStats.tokens) +}) + +const formatKeyCost = computed(() => { + if (!props.todayStats) return '0.00' + return props.todayStats.cost.toFixed(2) +}) + +const formatKeyUserCost = computed(() => { + if (!props.todayStats || props.todayStats.user_cost == null) return '0.00' + return props.todayStats.user_cost.toFixed(2) +}) + onMounted(() => { if (!shouldAutoLoadUsageOnMount.value) return - loadUsage() + const source = isAnthropicOAuthOrSetupToken.value ? 'passive' : undefined + loadUsage(source) }) watch(openAIUsageRefreshKey, (nextKey, prevKey) => { if (!prevKey || nextKey === prevKey) return if (props.account.platform !== 'openai' || props.account.type !== 'oauth') return - if (!isActiveOpenAIRateLimited.value && hasCodexUsage.value && !isOpenAICodexSnapshotStale.value) return loadUsage().catch((e) => { console.error('Failed to refresh OpenAI usage:', e) }) }) + +watch( + () => props.manualRefreshToken, + (nextToken, prevToken) => { + if (nextToken === prevToken) return + if (!shouldFetchUsage.value) return + + const source = isAnthropicOAuthOrSetupToken.value ? 
'passive' : undefined + loadUsage(source).catch((e) => { + console.error('Failed to refresh usage after manual refresh:', e) + }) + } +) diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 64524d51..2934fbd9 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -31,6 +31,57 @@

+ +
+
+
+ +

+ {{ t('admin.accounts.openai.oauthPassthroughDesc') }} +

+
+ +
+
+ +
+
+
@@ -89,100 +140,30 @@ role="group" aria-labelledby="bulk-edit-model-restriction-label" > - -
- - -
- - -
-
-

- - - - {{ t('admin.accounts.selectAllowedModels') }} -

-
- - - -

- {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} - {{ - t('admin.accounts.supportsAllModels') - }} +

+

+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }}

- -
-
-

+

@@ -599,6 +661,43 @@
+ +
+
+ + +
+
+

+ {{ t('admin.accounts.openai.wsModeDesc') }} +

+

+ {{ t(openAIWSModeConcurrencyHintKey) }} +

+ + + + + +
@@ -2237,6 +2245,41 @@

+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.customBaseUrl.hint') }} +

+
+ +
+
+ +
+
@@ -2504,6 +2547,7 @@ :allow-multiple="form.platform === 'anthropic'" :show-cookie-option="form.platform === 'anthropic'" :show-refresh-token-option="form.platform === 'openai' || form.platform === 'sora' || form.platform === 'antigravity'" + :show-mobile-refresh-token-option="form.platform === 'openai'" :show-session-token-option="form.platform === 'sora'" :show-access-token-option="form.platform === 'sora'" :platform="form.platform" @@ -2511,6 +2555,7 @@ @generate-url="handleGenerateUrl" @cookie-auth="handleCookieAuth" @validate-refresh-token="handleValidateRefreshToken" + @validate-mobile-refresh-token="handleOpenAIValidateMobileRT" @validate-session-token="handleValidateSessionToken" @import-access-token="handleImportAccessToken" /> @@ -3080,9 +3125,13 @@ const umqModeOptions = computed(() => [ { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, ]) const tlsFingerprintEnabled = ref(false) +const tlsFingerprintProfileId = ref(null) +const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideTarget = ref('5m') +const customBaseUrlEnabled = ref(false) +const customBaseUrl = ref('') // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') @@ -3245,6 +3294,10 @@ watch( () => props.show, (newVal) => { if (newVal) { + // Load TLS fingerprint profiles + adminAPI.tlsFingerprintProfiles.list() + .then(profiles => { tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name })) }) + .catch(() => { tlsFingerprintProfiles.value = [] }) // Modal opened - fill related models allowedModels.value = [...getModelsByPlatform(form.platform)] // Antigravity: 默认使用映射模式并填充默认映射 @@ -3745,9 +3798,12 @@ const resetForm = () => { rpmStickyBuffer.value = null userMsgQueueMode.value = '' 
tlsFingerprintEnabled.value = false + tlsFingerprintProfileId.value = null sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false cacheTTLOverrideTarget.value = '5m' + customBaseUrlEnabled.value = false + customBaseUrl.value = '' allowOverages.value = false antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' @@ -4360,11 +4416,14 @@ const handleOpenAIExchange = async (authCode: string) => { } // OpenAI 手动 RT 批量验证和创建 -const handleOpenAIValidateRT = async (refreshTokenInput: string) => { +// OpenAI Mobile RT 使用的 client_id(与后端 openai.SoraClientID 一致) +const OPENAI_MOBILE_RT_CLIENT_ID = 'app_LlGpXReQgckcGGUo2JrYvtJK' + +// OpenAI/Sora RT 批量验证和创建(共享逻辑) +const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string) => { const oauthClient = activeOpenAIOAuth.value if (!refreshTokenInput.trim()) return - // Parse multiple refresh tokens (one per line) const refreshTokens = refreshTokenInput .split('\n') .map((rt) => rt.trim()) @@ -4389,7 +4448,8 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { try { const tokenInfo = await oauthClient.validateRefreshToken( refreshTokens[i], - form.proxy_id + form.proxy_id, + clientId ) if (!tokenInfo) { failedCount++ @@ -4399,6 +4459,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { } const credentials = oauthClient.buildCredentials(tokenInfo) + if (clientId) { + credentials.client_id = clientId + } const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) @@ -4410,8 +4473,9 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { } } - // Generate account name with index for batch - const accountName = refreshTokens.length > 1 ? 
`${form.name} #${i + 1}` : form.name + // Generate account name; fallback to email if name is empty (ent schema requires NotEmpty) + const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account' + const accountName = refreshTokens.length > 1 ? `${baseName} #${i + 1}` : baseName let openaiAccountId: string | number | undefined @@ -4494,6 +4558,12 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { } } +// 手动输入 RT(Codex CLI client_id,默认) +const handleOpenAIValidateRT = (rt: string) => handleOpenAIBatchRT(rt) + +// 手动输入 Mobile RT(SoraClientID) +const handleOpenAIValidateMobileRT = (rt: string) => handleOpenAIBatchRT(rt, OPENAI_MOBILE_RT_CLIENT_ID) + // Sora 手动 ST 批量验证和创建 const handleSoraValidateST = async (sessionTokenInput: string) => { const oauthClient = activeOpenAIOAuth.value @@ -4809,6 +4879,9 @@ const handleAnthropicExchange = async (authCode: string) => { // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true + if (tlsFingerprintProfileId.value) { + extra.tls_fingerprint_profile_id = tlsFingerprintProfileId.value + } } // Add session ID masking settings @@ -4822,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => { extra.cache_ttl_override_target = cacheTTLOverrideTarget.value } + // Add custom base URL settings + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + extra.custom_base_url_enabled = true + extra.custom_base_url = customBaseUrl.value.trim() + } + const credentials: Record = { ...tokenInfo } applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra) @@ -4924,6 +5003,9 @@ const handleCookieAuth = async (sessionKey: string) => { // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true + if (tlsFingerprintProfileId.value) { + extra.tls_fingerprint_profile_id = 
tlsFingerprintProfileId.value + } } // Add session ID masking settings @@ -4937,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => { extra.cache_ttl_override_target = cacheTTLOverrideTarget.value } + // Add custom base URL settings + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + extra.custom_base_url_enabled = true + extra.custom_base_url = customBaseUrl.value.trim() + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name const credentials: Record = { ...tokenInfo } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index c2f2f7d2..607e7a69 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1504,6 +1504,14 @@ />
+ +
+ +
@@ -1572,6 +1580,41 @@

+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.customBaseUrl.hint') }} +

+
+ +
+
+ +
+
@@ -1841,9 +1884,13 @@ const umqModeOptions = computed(() => [ { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, ]) const tlsFingerprintEnabled = ref(false) +const tlsFingerprintProfileId = ref(null) +const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideTarget = ref('5m') +const customBaseUrlEnabled = ref(false) +const customBaseUrl = ref('') // OpenAI 自动透传开关(OAuth/API Key) const openaiPassthroughEnabled = ref(false) @@ -1980,276 +2027,296 @@ const normalizePoolModeRetryCount = (value: number) => { return normalized } -watch( - () => props.account, - (newAccount) => { - if (newAccount) { - antigravityMixedChannelConfirmed.value = false - showMixedChannelWarning.value = false - mixedChannelWarningDetails.value = null - mixedChannelWarningRawMessage.value = '' - mixedChannelWarningAction.value = null - form.name = newAccount.name - form.notes = newAccount.notes || '' - form.proxy_id = newAccount.proxy_id - form.concurrency = newAccount.concurrency - form.load_factor = newAccount.load_factor ?? null - form.priority = newAccount.priority - form.rate_multiplier = newAccount.rate_multiplier ?? 1 - form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') - ? newAccount.status - : 'active' - form.group_ids = newAccount.group_ids || [] - form.expires_at = newAccount.expires_at ?? 
null +const syncFormFromAccount = (newAccount: Account | null) => { + if (!newAccount) { + return + } + antigravityMixedChannelConfirmed.value = false + showMixedChannelWarning.value = false + mixedChannelWarningDetails.value = null + mixedChannelWarningRawMessage.value = '' + mixedChannelWarningAction.value = null + form.name = newAccount.name + form.notes = newAccount.notes || '' + form.proxy_id = newAccount.proxy_id + form.concurrency = newAccount.concurrency + form.load_factor = newAccount.load_factor ?? null + form.priority = newAccount.priority + form.rate_multiplier = newAccount.rate_multiplier ?? 1 + form.status = (newAccount.status === 'active' || newAccount.status === 'inactive' || newAccount.status === 'error') + ? newAccount.status + : 'active' + form.group_ids = newAccount.group_ids || [] + form.expires_at = newAccount.expires_at ?? null - // Load intercept warmup requests setting (applies to all account types) - const credentials = newAccount.credentials as Record | undefined - interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true - autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true + // Load intercept warmup requests setting (applies to all account types) + const credentials = newAccount.credentials as Record | undefined + interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true + autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true - // Load mixed scheduling setting (only for antigravity accounts) - mixedScheduling.value = false - allowOverages.value = false - const extra = newAccount.extra as Record | undefined - mixedScheduling.value = extra?.mixed_scheduling === true - allowOverages.value = extra?.allow_overages === true + // Load mixed scheduling setting (only for antigravity accounts) + mixedScheduling.value = false + allowOverages.value = false + const extra = newAccount.extra as Record | undefined + mixedScheduling.value = extra?.mixed_scheduling === true + 
allowOverages.value = extra?.allow_overages === true - // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) - openaiPassthroughEnabled.value = false - openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF - openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF - codexCLIOnlyEnabled.value = false - anthropicPassthroughEnabled.value = false - if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { - openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true - openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { - modeKey: 'openai_oauth_responses_websockets_v2_mode', - enabledKey: 'openai_oauth_responses_websockets_v2_enabled', - fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], - defaultMode: OPENAI_WS_MODE_OFF - }) - openaiAPIKeyResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { - modeKey: 'openai_apikey_responses_websockets_v2_mode', - enabledKey: 'openai_apikey_responses_websockets_v2_enabled', - fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], - defaultMode: OPENAI_WS_MODE_OFF - }) - if (newAccount.type === 'oauth') { - codexCLIOnlyEnabled.value = extra?.codex_cli_only === true - } - } - if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') { - anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true - } + // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) + openaiPassthroughEnabled.value = false + openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF + openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF + codexCLIOnlyEnabled.value = false + anthropicPassthroughEnabled.value = false + if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { + openaiPassthroughEnabled.value = extra?.openai_passthrough === true || 
extra?.openai_oauth_passthrough === true + openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) + openaiAPIKeyResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_apikey_responses_websockets_v2_mode', + enabledKey: 'openai_apikey_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) + if (newAccount.type === 'oauth') { + codexCLIOnlyEnabled.value = extra?.codex_cli_only === true + } + } + if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') { + anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true + } - // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) - if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { - const quotaVal = extra?.quota_limit as number | undefined - editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null - const dailyVal = extra?.quota_daily_limit as number | undefined - editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null - const weeklyVal = extra?.quota_weekly_limit as number | undefined - editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null - // Load quota reset mode config - editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null - editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null - editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null - editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null - editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? 
null - editResetTimezone.value = (extra?.quota_reset_timezone as string) || null + // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) + if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') { + const quotaVal = extra?.quota_limit as number | undefined + editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null + const dailyVal = extra?.quota_daily_limit as number | undefined + editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null + const weeklyVal = extra?.quota_weekly_limit as number | undefined + editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null + // Load quota reset mode config + editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null + editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null + editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null + editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null + editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? 
null + editResetTimezone.value = (extra?.quota_reset_timezone as string) || null + } else { + editQuotaLimit.value = null + editQuotaDailyLimit.value = null + editQuotaWeeklyLimit.value = null + editDailyResetMode.value = null + editDailyResetHour.value = null + editWeeklyResetMode.value = null + editWeeklyResetDay.value = null + editWeeklyResetHour.value = null + editResetTimezone.value = null + } + + // Load antigravity model mapping (Antigravity 只支持映射模式) + if (newAccount.platform === 'antigravity') { + const credentials = newAccount.credentials as Record | undefined + + // Antigravity 始终使用映射模式 + antigravityModelRestrictionMode.value = 'mapping' + antigravityWhitelistModels.value = [] + + // 从 model_mapping 读取映射配置 + const rawAgMapping = credentials?.model_mapping as Record | undefined + if (rawAgMapping && typeof rawAgMapping === 'object') { + const entries = Object.entries(rawAgMapping) + // 无论是白名单样式(key===value)还是真正的映射,都统一转换为映射列表 + antigravityModelMappings.value = entries.map(([from, to]) => ({ from, to })) + } else { + // 兼容旧数据:从 model_whitelist 读取,转换为映射格式 + const rawWhitelist = credentials?.model_whitelist + if (Array.isArray(rawWhitelist) && rawWhitelist.length > 0) { + antigravityModelMappings.value = rawWhitelist + .map((v) => String(v).trim()) + .filter((v) => v.length > 0) + .map((m) => ({ from: m, to: m })) } else { - editQuotaLimit.value = null - editQuotaDailyLimit.value = null - editQuotaWeeklyLimit.value = null - editDailyResetMode.value = null - editDailyResetHour.value = null - editWeeklyResetMode.value = null - editWeeklyResetDay.value = null - editWeeklyResetHour.value = null - editResetTimezone.value = null - } - - // Load antigravity model mapping (Antigravity 只支持映射模式) - if (newAccount.platform === 'antigravity') { - const credentials = newAccount.credentials as Record | undefined - - // Antigravity 始终使用映射模式 - antigravityModelRestrictionMode.value = 'mapping' - antigravityWhitelistModels.value = [] - - // 从 model_mapping 读取映射配置 - const 
rawAgMapping = credentials?.model_mapping as Record | undefined - if (rawAgMapping && typeof rawAgMapping === 'object') { - const entries = Object.entries(rawAgMapping) - // 无论是白名单样式(key===value)还是真正的映射,都统一转换为映射列表 - antigravityModelMappings.value = entries.map(([from, to]) => ({ from, to })) - } else { - // 兼容旧数据:从 model_whitelist 读取,转换为映射格式 - const rawWhitelist = credentials?.model_whitelist - if (Array.isArray(rawWhitelist) && rawWhitelist.length > 0) { - antigravityModelMappings.value = rawWhitelist - .map((v) => String(v).trim()) - .filter((v) => v.length > 0) - .map((m) => ({ from: m, to: m })) - } else { - antigravityModelMappings.value = [] - } - } - } else { - antigravityModelRestrictionMode.value = 'mapping' - antigravityWhitelistModels.value = [] antigravityModelMappings.value = [] } + } + } else { + antigravityModelRestrictionMode.value = 'mapping' + antigravityWhitelistModels.value = [] + antigravityModelMappings.value = [] + } - // Load quota control settings (Anthropic OAuth/SetupToken only) - loadQuotaControlSettings(newAccount) + // Load quota control settings (Anthropic OAuth/SetupToken only) + loadQuotaControlSettings(newAccount) - loadTempUnschedRules(credentials) + loadTempUnschedRules(credentials) - // Initialize API Key fields for apikey type - if (newAccount.type === 'apikey' && newAccount.credentials) { - const credentials = newAccount.credentials as Record - const platformDefaultUrl = - newAccount.platform === 'openai' || newAccount.platform === 'sora' - ? 'https://api.openai.com' - : newAccount.platform === 'gemini' - ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' - editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl + // Initialize API Key fields for apikey type + if (newAccount.type === 'apikey' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + const platformDefaultUrl = + newAccount.platform === 'openai' || newAccount.platform === 'sora' + ? 
'https://api.openai.com' + : newAccount.platform === 'gemini' + ? 'https://generativelanguage.googleapis.com' + : 'https://api.anthropic.com' + editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl - // Load model mappings and detect mode - const existingMappings = credentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) + // Load model mappings and detect mode + const existingMappings = credentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) - // Detect if this is whitelist mode (all from === to) or mapping mode - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + // Detect if this is whitelist mode (all from === to) or mapping mode + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - // Whitelist mode: populate allowedModels - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - // Mapping mode: populate modelMappings - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - // No mappings: default to whitelist mode with empty selection (allow all) - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - - // Load pool mode - poolModeEnabled.value = credentials.pool_mode === true - poolModeRetryCount.value = normalizePoolModeRetryCount( - Number(credentials.pool_mode_retry_count ?? 
DEFAULT_POOL_MODE_RETRY_COUNT) - ) - - // Load custom error codes - customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true - const existingErrorCodes = credentials.custom_error_codes as number[] | undefined - if (existingErrorCodes && Array.isArray(existingErrorCodes)) { - selectedErrorCodes.value = [...existingErrorCodes] - } else { - selectedErrorCodes.value = [] - } - } else if (newAccount.type === 'bedrock' && newAccount.credentials) { - const bedrockCreds = newAccount.credentials as Record - const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' - editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' - editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' - - if (authMode === 'apikey') { - editBedrockApiKeyValue.value = '' - } else { - editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' - editBedrockSecretAccessKey.value = '' - editBedrockSessionToken.value = '' - } - - // Load pool mode for bedrock - poolModeEnabled.value = bedrockCreds.pool_mode === true - const retryCount = bedrockCreds.pool_mode_retry_count - poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT - - // Load quota limits for bedrock - const bedrockExtra = (newAccount.extra as Record) || {} - editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null - editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null - editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? 
bedrockExtra.quota_weekly_limit : null - - // Load model mappings for bedrock - const existingMappings = bedrockCreds.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - } else if (newAccount.type === 'upstream' && newAccount.credentials) { - const credentials = newAccount.credentials as Record - editBaseUrl.value = (credentials.base_url as string) || '' + if (isWhitelistMode) { + // Whitelist mode: populate allowedModels + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] } else { - const platformDefaultUrl = - newAccount.platform === 'openai' || newAccount.platform === 'sora' - ? 'https://api.openai.com' - : newAccount.platform === 'gemini' - ? 
'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' - editBaseUrl.value = platformDefaultUrl + // Mapping mode: populate modelMappings + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + // No mappings: default to whitelist mode with empty selection (allow all) + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } - // Load model mappings for OpenAI OAuth accounts - if (newAccount.platform === 'openai' && newAccount.credentials) { - const oauthCredentials = newAccount.credentials as Record - const existingMappings = oauthCredentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } - } else { + // Load pool mode + poolModeEnabled.value = credentials.pool_mode === true + poolModeRetryCount.value = normalizePoolModeRetryCount( + Number(credentials.pool_mode_retry_count ?? 
DEFAULT_POOL_MODE_RETRY_COUNT) + ) + + // Load custom error codes + customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true + const existingErrorCodes = credentials.custom_error_codes as number[] | undefined + if (existingErrorCodes && Array.isArray(existingErrorCodes)) { + selectedErrorCodes.value = [...existingErrorCodes] + } else { + selectedErrorCodes.value = [] + } + } else if (newAccount.type === 'bedrock' && newAccount.credentials) { + const bedrockCreds = newAccount.credentials as Record + const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' + editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' + editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' + + if (authMode === 'apikey') { + editBedrockApiKeyValue.value = '' + } else { + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + } + + // Load pool mode for bedrock + poolModeEnabled.value = bedrockCreds.pool_mode === true + const retryCount = bedrockCreds.pool_mode_retry_count + poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT + + // Load quota limits for bedrock + const bedrockExtra = (newAccount.extra as Record) || {} + editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null + editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null + editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? 
bedrockExtra.quota_weekly_limit : null + + // Load model mappings for bedrock + const existingMappings = bedrockCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else if (newAccount.type === 'upstream' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editBaseUrl.value = (credentials.base_url as string) || '' + } else { + const platformDefaultUrl = + newAccount.platform === 'openai' || newAccount.platform === 'sora' + ? 'https://api.openai.com' + : newAccount.platform === 'gemini' + ? 
'https://generativelanguage.googleapis.com' + : 'https://api.anthropic.com' + editBaseUrl.value = platformDefaultUrl + + // Load model mappings for OpenAI OAuth accounts + if (newAccount.platform === 'openai' && newAccount.credentials) { + const oauthCredentials = newAccount.credentials as Record + const existingMappings = oauthCredentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) allowedModels.value = [] } - poolModeEnabled.value = false - poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT - customErrorCodesEnabled.value = false - selectedErrorCodes.value = [] + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] } - editApiKey.value = '' + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + poolModeEnabled.value = false + poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT + customErrorCodesEnabled.value = false + selectedErrorCodes.value = [] + } + editApiKey.value = '' +} + +watch( + [() => props.show, () => props.account], + ([show, newAccount], [wasShow, previousAccount]) => { + if (!show || !newAccount) { + return + } + if (!wasShow || newAccount !== previousAccount) { + syncFormFromAccount(newAccount) + loadTLSProfiles() } }, { immediate: true } ) +const loadTLSProfiles = async () => { + try { + const profiles = await adminAPI.tlsFingerprintProfiles.list() + tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name })) + } catch { + 
tlsFingerprintProfiles.value = [] + } +} + // Model mapping helpers const addModelMapping = () => { modelMappings.value.push({ from: '', to: '' }) @@ -2448,9 +2515,12 @@ function loadQuotaControlSettings(account: Account) { rpmStickyBuffer.value = null userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false + tlsFingerprintProfileId.value = null sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false cacheTTLOverrideTarget.value = '5m' + customBaseUrlEnabled.value = false + customBaseUrl.value = '' // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -2485,6 +2555,7 @@ function loadQuotaControlSettings(account: Account) { if (account.enable_tls_fingerprint === true) { tlsFingerprintEnabled.value = true } + tlsFingerprintProfileId.value = account.tls_fingerprint_profile_id ?? null // Load session ID masking setting if (account.session_id_masking_enabled === true) { @@ -2496,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) { cacheTTLOverrideEnabled.value = true cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' } + + // Load custom base URL setting + if (account.custom_base_url_enabled === true) { + customBaseUrlEnabled.value = true + customBaseUrl.value = account.custom_base_url || '' + } } function formatTempUnschedKeywords(value: unknown) { @@ -2922,8 +2999,14 @@ const handleSubmit = async () => { // TLS fingerprint setting if (tlsFingerprintEnabled.value) { newExtra.enable_tls_fingerprint = true + if (tlsFingerprintProfileId.value) { + newExtra.tls_fingerprint_profile_id = tlsFingerprintProfileId.value + } else { + delete newExtra.tls_fingerprint_profile_id + } } else { delete newExtra.enable_tls_fingerprint + delete newExtra.tls_fingerprint_profile_id } // Session ID masking setting @@ -2942,6 +3025,15 @@ const handleSubmit = async () => { delete newExtra.cache_ttl_override_target } + // 
Custom base URL relay setting + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + newExtra.custom_base_url_enabled = true + newExtra.custom_base_url = customBaseUrl.value.trim() + } else { + delete newExtra.custom_base_url_enabled + delete newExtra.custom_base_url + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index cc74f8ce..b4c299db 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -48,6 +48,17 @@ t(getOAuthKey('refreshTokenAuth')) }} +
- -
+ +
@@ -759,6 +770,7 @@ interface Props { methodLabel?: string showCookieOption?: boolean // Whether to show cookie auto-auth option showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) + showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only) showSessionTokenOption?: boolean // Whether to show session token input option (Sora only) showAccessTokenOption?: boolean // Whether to show access token input option (Sora only) platform?: AccountPlatform // Platform type for different UI/text @@ -776,6 +788,7 @@ const props = withDefaults(defineProps(), { methodLabel: 'Authorization Method', showCookieOption: true, showRefreshTokenOption: false, + showMobileRefreshTokenOption: false, showSessionTokenOption: false, showAccessTokenOption: false, platform: 'anthropic', @@ -787,6 +800,7 @@ const emit = defineEmits<{ 'exchange-code': [code: string] 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] + 'validate-mobile-refresh-token': [refreshToken: string] 'validate-session-token': [sessionToken: string] 'import-access-token': [accessToken: string] 'update:inputMethod': [method: AuthInputMethod] @@ -834,7 +848,7 @@ const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -945,7 +959,11 @@ const handleCookieAuth = () => { const handleValidateRefreshToken = () => { if (refreshTokenInput.value.trim()) { - emit('validate-refresh-token', 
refreshTokenInput.value.trim()) + if (inputMethod.value === 'mobile_refresh_token') { + emit('validate-mobile-refresh-token', refreshTokenInput.value.trim()) + } else { + emit('validate-refresh-token', refreshTokenInput.value.trim()) + } } } diff --git a/frontend/src/components/account/UsageProgressBar.vue b/frontend/src/components/account/UsageProgressBar.vue index cd5c991f..52f0ecbb 100644 --- a/frontend/src/components/account/UsageProgressBar.vue +++ b/frontend/src/components/account/UsageProgressBar.vue @@ -2,7 +2,7 @@
@@ -12,12 +12,13 @@ {{ formatTokens }} - + A ${{ formatAccountCost }} U ${{ formatUserCost }} @@ -47,7 +48,7 @@ - + {{ formatResetTime }}
@@ -55,8 +56,11 @@ diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index f5bc5aa0..06bd23ab 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -32,6 +32,10 @@ {{ t('admin.accounts.refreshToken') }} +
- -
- - -
@@ -177,7 +166,6 @@ import { ref, onMounted, onUnmounted, toRef, watch } from 'vue' import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api/admin' import Select, { type SelectOption } from '@/components/common/Select.vue' -import DateRangePicker from '@/components/common/DateRangePicker.vue' import type { SimpleApiKey, SimpleUser } from '@/api/admin/usage' type ModelValue = Record @@ -195,8 +183,6 @@ const props = withDefaults(defineProps(), { }) const emit = defineEmits([ 'update:modelValue', - 'update:startDate', - 'update:endDate', 'change', 'refresh', 'reset', @@ -248,16 +234,6 @@ const billingTypeOptions = ref([ const emitChange = () => emit('change') -const updateStartDate = (value: string) => { - emit('update:startDate', value) - filters.value.start_date = value -} - -const updateEndDate = (value: string) => { - emit('update:endDate', value) - filters.value.end_date = value -} - const debounceUserSearch = () => { if (userSearchTimeout) clearTimeout(userSearchTimeout) userSearchTimeout = setTimeout(async () => { @@ -441,7 +417,11 @@ onMounted(async () => { groupOptions.value.push(...gs.items.map((g: any) => ({ value: g.id, label: g.name }))) const uniqueModels = new Set() - ms.models?.forEach((s: any) => s.model && uniqueModels.add(s.model)) + ms.models?.forEach((s: any) => { + if (s.model) { + uniqueModels.add(s.model) + } + }) modelOptions.value.push( ...Array.from(uniqueModels) .sort() diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index aa6c2bbd..4a42ab05 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -25,8 +25,16 @@ {{ row.account?.name || '-' }} -